model_runner.py 47.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import collections
17
import datetime
18
import gc
19
import inspect
Shuo Yang's avatar
Shuo Yang committed
20
import json
21
import logging
22
import os
23
import time
24
25
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27

import torch
28
import torch.distributed as dist
29
30
31
32
33

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
34
    get_tp_group,
35
    get_world_group,
zhyncs's avatar
zhyncs committed
36
37
    init_distributed_environment,
    initialize_model_parallel,
38
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
39
)
40
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
41
42
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
43
    get_attention_tp_size,
44
45
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
46
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
47
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
48
49
50
51
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
52
from sglang.srt.layers.sampler import Sampler
53
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
54
from sglang.srt.lora.lora_manager import LoRAManager
55
from sglang.srt.managers.schedule_batch import global_server_args_dict
56
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
57
    DoubleSparseTokenToKVPool,
58
59
60
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
61
    TokenToKVPoolAllocator,
62
)
Lianmin Zheng's avatar
Lianmin Zheng committed
63
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
64
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
65
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
66
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
69
70
71
72
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
73
from sglang.srt.model_loader.weight_utils import default_weight_loader
74
from sglang.srt.patch_torch import monkey_patch_torch_reductions
75
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
76
from sglang.srt.server_args import ServerArgs
77
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
78
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
79
from sglang.srt.utils import (
80
    MultiprocessingSerializer,
81
    enable_show_time_cost,
82
    get_available_gpu_memory,
83
    get_bool_env_var,
84
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
85
    is_cuda,
86
    is_fa3_default_architecture,
87
    is_flashinfer_available,
HAI's avatar
HAI committed
88
    is_hip,
89
    is_hopper_with_cuda_12_3,
90
    is_no_spec_infer_or_topk_one,
91
    monkey_patch_p2p_access_check,
92
    monkey_patch_vllm_gguf_config,
93
    set_cpu_offload_max_bytes,
94
    set_cuda_arch,
95
)
96

Lianmin Zheng's avatar
Lianmin Zheng committed
97
# Optional override that keeps the KV cache pool small for tests in CI.
# Stays None when the environment variable is not set.
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Detect straggler ranks in model loading: seconds to wait at the post-load
# barrier before reporting that some ranks did not finish loading.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

105

Lianmin Zheng's avatar
Lianmin Zheng committed
106
class ModelRunner:
107
108
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
109
110
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Build the runner for one (tp_rank, pp_rank) worker.

        The constructor records its arguments, lets
        ``model_specific_adjustment`` rewrite ``server_args`` for this model,
        publishes selected settings into ``global_server_args_dict``, sets up
        torch distributed, and finally loads the model via ``initialize``.

        Args:
            model_config: Model configuration.
            mem_fraction_static: Static memory fraction; reduced for
                multimodal models by ``model_specific_adjustment``.
            gpu_id: Local device index for this process.
            tp_rank: Tensor-parallel rank (rank 0 emits info logs).
            tp_size: Tensor-parallel world size.
            pp_rank: Pipeline-parallel rank.
            pp_size: Pipeline-parallel world size.
            nccl_port: Port for the default TCP init method, used when
                ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments; mutated in place.
            is_draft_worker: True for a speculative-decoding draft model; such
                workers do not initialize the distributed environment.
            req_to_token_pool: Optional pre-existing request-to-token pool.
            token_to_kv_pool_allocator: Optional pre-existing KV allocator.
        """
        # ---- Record constructor arguments and derived flags ----
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank, self.tp_size = tp_rank, tp_size
        self.pp_rank, self.pp_size = pp_rank, pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        # Only TP rank 0 emits informational logs.
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        # May rewrite fields of `server_args` (attention backend, chunked
        # prefill, ...) before they are published below.
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Expose selected settings through the process-wide dictionary.
        shared_settings = {
            "attention_backend": server_args.attention_backend,
            "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
            "deepep_mode": server_args.deepep_mode,
            "device": server_args.device,
            "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
            "disable_radix_cache": server_args.disable_radix_cache,
            "enable_nan_detection": server_args.enable_nan_detection,
            "enable_dp_attention": server_args.enable_dp_attention,
            "enable_ep_moe": server_args.enable_ep_moe,
            "enable_deepep_moe": server_args.enable_deepep_moe,
            "deepep_config": server_args.deepep_config,
            "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
            "moe_dense_tp_size": server_args.moe_dense_tp_size,
            "n_share_experts_fusion": server_args.n_share_experts_fusion,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "torchao_config": server_args.torchao_config,
            "sampling_backend": server_args.sampling_backend,
            "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
            "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
            "use_mla_backend": self.use_mla_backend,
            "mm_attention_backend": server_args.mm_attention_backend,
        }
        global_server_args_dict.update(shared_settings)

        # CPU offload budget, converted from GB to bytes.
        offload_bytes = int(server_args.cpu_offload_gb * 1024**3)
        set_cpu_offload_max_bytes(offload_bytes)

        # Set up torch distributed; the result is the memory available before
        # model loading.
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # Load the model and its pools. For a draft model the TP group can
        # differ from the target model's.
        self.initialize(min_per_gpu_memory)

        # Cached: whether the model's forward accepts pipeline proxy tensors.
        forward_params = inspect.signature(self.model.forward).parameters
        self.support_pp = "pp_proxy_tensors" in forward_params

200
201
    def initialize(self, min_per_gpu_memory: float):
        """Load the model and create everything needed to run forward passes.

        Args:
            min_per_gpu_memory: Memory reported by ``init_torch_distributed``
                before model loading; forwarded to ``init_memory_pool``.
        """
        args = self.server_args

        # `load_model` allocates the model inside this adapter's region.
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=args.enable_memory_saver
        )

        # ---- Model weights ----
        self.sampler = Sampler()
        self.load_model()

        # Layer range held by this model instance (presumably the pipeline
        # stage's slice); defaults cover every hidden layer.
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # ---- torchao quantization ----
        # Skipped when the loader already applied it (layered loading).
        already_quantized = getattr(self.model, "torchao_applied", False)
        if not already_quantized:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # ---- Torch-native tensor parallelism, for models that opt in ----
        model_supports_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and model_supports_tp:
            self.apply_torch_tp()

        # ---- LoRA ----
        if args.lora_paths is not None:
            self.init_lora_manager()

        # ---- Memory pool and attention backends ----
        self.init_memory_pool(
            min_per_gpu_memory,
            args.max_running_requests,
            args.max_total_tokens,
        )

        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            # CUDA graphs are only captured on CUDA devices.
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

251
252
253
    def model_specific_adjustment(self):
        server_args = self.server_args

254
        if server_args.attention_backend is None:
255
            """
Lianmin Zheng's avatar
Lianmin Zheng committed
256
257
            Auto select the fastest attention backend.

258
259
260
261
262
263
264
265
            1. Models with MHA Architecture (e.g: Llama, QWen)
                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
            2. Models with MLA Architecture and using FA3
                2.1 We will use FA3 backend on hopper.
                2.2 Otherwise, we will use triton backend.
            """

266
            if not self.use_mla_backend:
Lianmin Zheng's avatar
Lianmin Zheng committed
267
                # MHA architecture
268
                if (
269
                    is_hopper_with_cuda_12_3()
270
271
272
273
274
275
276
277
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
278
            else:
Lianmin Zheng's avatar
Lianmin Zheng committed
279
                # MLA architecture
280
                if is_hopper_with_cuda_12_3():
281
                    server_args.attention_backend = "fa3"
282
283
                else:
                    server_args.attention_backend = "triton"
284
285
286
287
            if self.should_log:
                logger.info(
                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
                )
288
        elif self.use_mla_backend:
289
            if server_args.device != "cpu":
290
291
292
293
294
                if server_args.attention_backend in [
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
295
                    "cutlass_mla",
296
                ]:
297
298
299
300
                    if self.should_log:
                        logger.info(
                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                        )
301
                else:
302
303
304
305
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
Lianmin Zheng's avatar
Lianmin Zheng committed
306
                raise ValueError("MLA optimization not supported on CPU.")
307

308
309
310
311
312
313
314
315
316
317
        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

318
        if server_args.enable_double_sparsity:
319
320
321
322
            if self.should_log:
                logger.info(
                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
                )
323
324
325
326
327
328
329
330
331
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
Mick's avatar
Mick committed
332
            self.mem_fraction_static *= 0.90
333
334
335
336
337
338
339
340
            if self.should_log:
                logger.info(
                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                    f"because this is a multimodal model."
                )
                logger.info(
                    "Automatically turn off --chunked-prefill-size for multimodal model."
                )
Mick's avatar
Mick committed
341
            server_args.chunked_prefill_size = -1
342

343
344
345
        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
346
347
            if self.should_log:
                logger.info("Disable chunked prefix cache when page size > 1.")
348
349
350
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
351
352
            if self.should_log:
                logger.info("Chunked prefix cache is turned on.")
353

354
    def init_torch_distributed(self):
        """Bind this process to its device and initialize distributed state.

        For a target (non-draft) worker this initializes the distributed
        environment, model parallelism and DP attention. Draft workers skip
        that step, presumably relying on groups created by the target worker.
        The TP groups are then cached on ``self``, and available memory is
        checked for balance across TP ranks.

        Returns:
            Available device memory as reported by ``get_available_gpu_memory``
            with ``distributed=world_size > 1`` (used as the per-GPU minimum;
            GB per the log messages).

        Raises:
            ValueError: If a target worker runs on a device with no known
                collective backend, or if memory is unbalanced across TP ranks
                and ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not set.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Collective-communication backend per device type. Unknown devices map
        # to None; previously they left `backend` unbound, and the target worker
        # then failed with an UnboundLocalError.
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
            "npu": "hccl",
        }.get(self.device)

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            if backend is None:
                raise ValueError(
                    f"Unsupported device for torch distributed: {self.device}"
                )
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism: flag a rank whose free memory is
        # noticeably larger than the minimum seen on the other ranks.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            message = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(message)
            else:
                raise ValueError(message)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
437

Lianmin Zheng's avatar
Lianmin Zheng committed
438
    def load_model(self):
        """Load the model weights onto ``self.device`` and synchronize TP ranks.

        On CUDA devices below sm80 this forces float16 and updates both
        ``server_args.dtype`` and ``model_config.dtype``. It sets
        ``self.load_config``, ``self.model``, ``self.sliding_window_size`` and
        ``self.dtype``, and loads FP8 KV-cache scales when configured.

        Raises:
            RuntimeError: On CUDA devices below sm75, or when FP8 KV-cache
                scales are provided but the model cannot load them.
            ValueError: If other TP ranks do not reach the post-load barrier
                within ``UNBALANCED_MODEL_LOADING_TIMEOUT_S`` seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                if self.should_log:
                    logger.info(
                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                    )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare the full (major, minor) pair. Checking only the minor
                # version wrongly accepted devices such as sm35/sm37.
                if (major, minor) < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region():
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Undo the patches even if loading fails.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    if self.should_log:
                        logger.info(
                            "Loaded KV cache scaling factors from %s",
                            self.server_args.quantization_param_path,
                        )
                else:
                    # Format eagerly: exceptions do not apply logging-style
                    # %-arguments, so the old message was never formatted.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

535
536
537
538
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Path (or identifier) of the checkpoint to load.
            load_format: Load format of the new checkpoint.

        Returns:
            ``(success, message)``. On any failure ``self.model_config.model_path``
            is restored to its previous value. If the new weights were
            partially written, the original checkpoint is reloaded before
            returning.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember the current checkpoint so failures can restore it. The
        # weights iterator reads the path from `self.model_config`.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config, weight_loader):
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config, loader)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back from the ORIGINAL checkpoint. Previously the path had
                # already been switched to `model_path`, so the rollback reloaded
                # the checkpoint that had just failed.
                self.model_config.model_path = original_model_path
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    # Cannot iterate weights with the original loader type;
                    # fall back to the new loader on the original path.
                    rollback_loader = loader
                self.model = model_load_weights(
                    self.model, get_weight_iter(self.model_config, rollback_loader)
                )
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
589

590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
618
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """Receive one parameter over `_model_update_group` and load it into the model.

        A buffer of the requested dtype/shape is allocated on `self.device`, filled
        by a broadcast whose source is rank 0 of the update group, and then passed
        to `self.model.load_weights`.

        Args:
            name: name of the parameter to be updated.
            dtype: a `torch.dtype`, or the name of a `torch` attribute resolving to
                one (e.g. "bfloat16").
            shape: shape of the parameter to be updated.

        Returns:
            A `(success, message)` tuple. On failure the model may already hold a
            partially-updated set of weights.
        """
        if isinstance(dtype, torch.dtype):
            target_dtype = dtype
        else:
            target_dtype = getattr(torch, dtype)

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            received = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(
                received, src=0, group=self._model_update_group
            )
            self.model.load_weights([(name, received)])
            return True, f"Succeeded to update parameter {name} online."
        except Exception as e:
            failure = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(failure)
            return False, failure

668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load in-memory `(name, tensor)` pairs into the model.

        Every value is first passed through `_unwrap_tensor` for this runner's
        `tp_rank`. With `load_format=None` the pairs go to the model's own
        `load_weights`; with `load_format="direct"` they are handed to
        `_model_load_weights_direct`.

        Returns:
            `(True, "Success")` once the weights are loaded.

        Raises:
            NotImplementedError: for any other `load_format`.
        """
        unwrapped = []
        for param_name, value in named_tensors:
            unwrapped.append(
                (param_name, _unwrap_tensor(value, tp_rank=self.tp_rank))
            )

        if load_format is None:
            self.model.load_weights(unwrapped)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
684

685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Look up a parameter by name, similar to Hugging Face's `get_parameter`.

        The lookup is delegated to `self.model.get_weights_by_name(name,
        truncate_size, tp_size=self.tp_size)`. Any error is logged and turned
        into a `None` result.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None
        return weights

702
703
704
705
706
707
708
709
    def init_lora_manager(self):
        """Create `self.lora_manager` around the loaded base model.

        LoRA settings (adapter paths, per-batch limit, backend) come from
        `server_args`. Model config, load config, dtype and TP layout come from
        this runner.
        """
        server_args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")

716
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit into the memory left for it.

        The budget is the currently available GPU memory minus the part of
        `total_gpu_memory` not covered by `mem_fraction_static`. It is converted
        from GB to bytes (`1 << 30`) and divided by `cell_size`, the per-token
        KV-cache footprint summed over all layers.

        Args:
            total_gpu_memory: total GPU memory, in the same unit (GB) as
                `get_available_gpu_memory` reports.

        Returns:
            The number of tokens that fit, as an int (may be <= 0).
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )

        if not self.use_mla_backend:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.num_effective_layers
                * 2  # presumably one K and one V entry per head
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            # The draft worker only holds the next-n prediction layers.
            if self.is_draft_worker:
                num_layers = self.model_config.hf_config.num_nextn_predict_layers
            else:
                num_layers = self.model_config.num_hidden_layers
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            per_layer_width = (
                self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim
            )
            cell_size = (
                per_layer_width
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // cell_size)

750
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size the KV cache and create the request/token memory pools.

        Steps:
        1. Resolve `self.kv_cache_dtype` from `server_args.kv_cache_dtype`.
        2. Profile how many KV-cache tokens fit (`profile_max_num_token`).
        3. Derive `max_num_reqs` if it was not given.
        4. Apply the CI override, the speculative-decoding adjustments and the
           user cap `max_total_tokens`.
        5. Round down to a whole number of pages.
        6. Build `req_to_token_pool`, `token_to_kv_pool` and
           `token_to_kv_pool_allocator`. A pre-existing request pool or allocator
           is kept, which is only allowed for a draft worker.

        Args:
            total_gpu_memory: forwarded to `profile_max_num_token`.
            max_num_reqs: maximum number of concurrent requests. When None it is
                derived from the profiled token budget and clamped to
                [2048, 4096].
            max_total_tokens: optional upper bound on the KV-cache size in tokens.

        Raises:
            ValueError: if `server_args.kv_cache_dtype` is not supported.
            RuntimeError: if not a single KV-cache token fits into memory.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            # Every accepted value must set kv_cache_dtype. Previously only the
            # CUDA case did, leaving the attribute unset on other platforms.
            # Mirror the fp8_e5m2 branch: HIP uses its natively supported format.
            if is_hip():
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger like the rest of this class (was the root
                # `logging` logger).
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            # Only a draft worker may reuse the target worker's allocator.
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
909

Lianmin Zheng's avatar
Lianmin Zheng committed
910
911
912
913
914
915
916
917
918
    def init_cublas(self):
        """Warm up cuBLAS with a tiny float16 matmul on the default CUDA device.

        Running a small matmul up front initializes cuBLAS; without it, errors
        are raised later. Returns the 16x16 product.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        return lhs @ rhs

919
920
    def init_attention_backend(self):
        """Init attention kernel backend.

        Instantiates the backend named by `server_args.attention_backend` and
        stores it in `self.attn_backend`. Backend modules are imported lazily, so
        only the selected backend's dependencies are loaded.

        Raises:
            AssertionError: if the selected backend cannot serve this model or
                GPU (window/cross attention on triton; unsupported SM on fa3).
            ValueError: if the backend name is not recognized.
        """
        backend = self.server_args.attention_backend

        if backend == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
        elif backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
        elif backend == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
        elif backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        elif backend == "fa3":
            # Query the device once. SM8x is accepted only without MLA; SM90 is
            # accepted with or without MLA. The message now states the MLA
            # restriction instead of implying every SM>=80 GPU is supported.
            major = torch.cuda.get_device_capability()[0]
            assert (major == 8 and not self.use_mla_backend) or major == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90 "
                "(MLA models require SM90). "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
        elif backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend}")
988

Shuo Yang's avatar
Shuo Yang committed
989
990
991
992
993
994
995
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer heavy-channel indices for double sparsity.

        Reads the JSON file at `server_args.ds_channel_config_path`. For every
        layer index in [start_layer, end_layer) it takes the entry
        "model.layers.{i}.self_attn.{selected_channel}_proj", keeps its first
        `ds_heavy_channel_num` columns, and appends the result, moved to the GPU
        with `.cuda()`, to `self.sorted_channels` (in layer order).
        """
        suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for layer_id in range(self.start_layer, self.end_layer):
            key = f"model.layers.{layer_id}.self_attn{suffix}"
            channels = torch.tensor(channel_config[key])
            heavy = channels[:, : self.server_args.ds_heavy_channel_num]
            self.sorted_channels.append(heavy.contiguous().cuda())

1006
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        `self.cuda_graph_runner` is reset to None. It stays None for
        non-generation models and when `server_args.disable_cuda_graph` is set.
        Otherwise it becomes a `CudaGraphRunner`, and the elapsed time and GPU
        memory consumed by the capture are logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        started_at = time.perf_counter()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        elapsed = time.perf_counter() - started_at
        logger.info(
            f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s. "
            f"mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )
1028

1029
    def apply_torch_tp(self):
        """Shard `self.model` with torch-native tensor parallelism.

        Builds a 1-D device mesh of size `tp_size` on `self.device` and passes it,
        together with the model, to `sglang.srt.model_parallel.tensor_parallel`.
        """
        if self.should_log:
            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

1037
1038
1039
    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run one decode step through the model.

        Initializes the attention backend's metadata for `forward_batch`, then
        calls `self.model.forward`. `pp_proxy_tensors` is passed only when the
        model supports pipeline parallelism (`self.support_pp`).
        """
        self.attn_backend.init_forward_metadata(forward_batch)

        # FIXME: add pp_proxy_tensors arg to all models
        extra = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

1049
    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        """Run an extend-mode batch through the model.

        Args:
            forward_batch: the batch to run.
            skip_attn_backend_init: when True, the caller has already prepared
                the attention backend's metadata.
            pp_proxy_tensors: passed to the model only if `self.support_pp`.

        Precomputed `input_embeds` are cast to bfloat16 before the call.
        Non-generation models additionally receive `get_embedding=True`.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        extra = {}
        if self.support_pp:
            extra["pp_proxy_tensors"] = pp_proxy_tensors
        input_embeds = forward_batch.input_embeds
        if input_embeds is not None:
            extra["input_embeds"] = input_embeds.bfloat16()
        if not self.is_generation:
            extra["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1071

1072
1073
1074
1075
1076
1077
    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model on an idle batch.

        Unlike the decode/extend paths, no attention-backend metadata is
        initialized here. `pp_proxy_tensors` is passed only when
        `self.support_pp` is set.
        """
        if self.support_pp:
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

1085
    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Dispatch `forward_batch` to the matching forward path.

        A captured CUDA graph is replayed when the forward mode allows it and the
        graph runner accepts the batch. Otherwise the batch goes to the decode,
        extend or idle path according to `forward_batch.forward_mode`.

        Returns:
            `(output, can_run_cuda_graph)`. The flag tells whether the CUDA graph
            replay path was taken.

        Raises:
            ValueError: if the forward mode is none of decode/extend/idle.
        """
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )

        if can_run_cuda_graph:
            output = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_decode():
            output = self.forward_decode(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
        elif forward_batch.forward_mode.is_extend():
            output = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            output = self.forward_idle(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return output, can_run_cuda_graph

1117
1118
1119
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Make sure the vocab mask is ready, then apply logit biases.

        Without a precomputed `sampling_info_done`, the regex vocab mask is built
        here. Otherwise, if grammars are present, this blocks on
        `sampling_info_done` before continuing.
        """
        if not sampling_info.sampling_info_done:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        elif sampling_info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            sampling_info.sampling_info_done.wait()

        # Apply logit bias
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward. A tuple of
                outputs (duplex models with multiple output streams) is sampled
                per stream.
            forward_batch: The forward batch that generates logits_output.

        Returns:
            A tensor of next token ids, as produced by `self.sampler`. For tuple
            inputs, the per-stream results are stacked along the last dimension.
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            # `dim` is torch's documented keyword (`axis` is only a numpy alias).
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                dim=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

Yineng Zhang's avatar
Yineng Zhang committed
1164
1165
1166
1167
1168
1169
1170
    @property
    def model_is_mrope(self) -> bool:
        """Whether the HF config's `rope_scaling` contains an "mrope_section" key.

        mrope needs "rope_deltas" to be kept between the prompt and decoding
        phases. A missing or None `rope_scaling` means the model is not mrope.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        return rope_scaling is not None and "mrope_section" in rope_scaling
1173

1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
    def save_remote_model(self, url: str):
        """Save the loaded model to `url` with `RemoteModelLoader.save_model`.

        `model_config.model_path` is passed along with the model.
        """
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        source_path = self.model_config.model_path
        RemoteModelLoader.save_model(self.model, source_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the loaded model under `path` with `ShardedStateLoader.save_model`.

        `pattern` and `max_size` are forwarded unchanged to the loader.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        summary = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(summary)
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

1190
1191
1192
1193
1194
1195
1196
1197
1198

def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return `tensor` moved to the current CUDA device.

    A `LocalSerializedTensor` is first resolved to the tensor for `tp_rank`
    (after calling `monkey_patch_torch_reductions()`). Any other value is
    moved as-is.
    """
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor.to(torch.cuda.current_device())
    monkey_patch_torch_reductions()
    return tensor.get(tp_rank).to(torch.cuda.current_device())
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212


@dataclass
class LocalSerializedTensor:
    """Per-rank handles of a torch.Tensor serialized by MultiprocessingSerializer.

    MultiprocessingSerializer only serializes a pointer, not the data.
    `values[i]` holds the serialized tensor for the i-th rank's GPU.
    """

    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the tensor stored for `rank`."""
        payload = self.values[rank]
        return MultiprocessingSerializer.deserialize(payload)