model_runner.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import socket
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
    AttentionArch,
    ModelConfig,
    get_nsa_index_head_dim,
    is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    attn_backend_wrapper,
)
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    HybridLinearKVPool,
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
    PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    is_npu,
    is_sm100_supported,
    log_info_on_rank0,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cuda_arch,
    slow_rank_detector,
)
from sglang.srt.utils.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

MLA_ATTENTION_BACKENDS = [
    "aiter",
    "flashinfer",
    "fa3",
    "fa4",
    "triton",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
]


def add_mla_attention_backend(backend_name):
    if backend_name not in MLA_ATTENTION_BACKENDS:
        MLA_ATTENTION_BACKENDS.append(backend_name)
        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

# The ratio of the mamba cache pool size to max_running_requests; it is safe when larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

logger = logging.getLogger(__name__)

if _is_npu:
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
        self.forward_pass_id = 0

        # Apply the rank zero filter to the logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        # Set the global server_args in the scheduler process
        set_global_server_args_for_scheduler(server_args)
        global_server_args = get_global_server_args()

        # FIXME: hacky way to set `use_mla_backend`
        global_server_args.use_mla_backend = self.use_mla_backend

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # CPU offload
        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
            slow_rank_detector.execute()

        # Update the DeepGEMM configuration
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # Temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}
        self._weights_send_group = {}

        if (
            self.server_args.enable_piecewise_cuda_graph
            and self.can_run_piecewise_cuda_graph()
        ):
            self.attention_layers = []
            for layer in self.model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
                    self.attention_layers.append(layer.self_attn.attn)
            if len(self.attention_layers) < self.model_config.num_hidden_layers:
                # TODO(yuwei): support Non-Standard GQA
                log_info_on_rank0(
                    logger,
                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
                )
                self.piecewise_cuda_graph_runner = None
            else:
                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
        else:
            self.piecewise_cuda_graph_runner = None

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        if config := self.mamba2_config:
            class_name = config.__class__.__name__
            logger.warning(f"{class_name} model detected, disable radix cache")
            self.server_args.disable_radix_cache = True

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (
            (not model_has_mtp_layers)
            or (self.spec_algorithm.is_none())
            or (
                (not self.spec_algorithm.is_none())
                and (self.num_effective_layers == model_num_layers)
            )
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, get_global_server_args().torchao_config
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init Double Sparsity
        if server_args.enable_double_sparsity:
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        # Enable batch invariant mode
        if server_args.enable_deterministic_inference:
            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

            enable_batch_invariant_mode()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_device_graphs()
        elif self.device in ["npu", "cpu"]:
            self.init_attention_backend()
            self.init_device_graphs()
        else:
            self.graph_runner = None
            self.graph_mem_usage = 0
            self.init_attention_backend()

        # Auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # Load the draft model config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # Get the aux layers from the draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except Exception:
                # If there is no aux layer config, fall back to None
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fallback to torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.prefill_attention_backend is not None and (
            server_args.prefill_attention_backend
            == server_args.decode_attention_backend
        ):  # override the default attention backend
            server_args.attention_backend = server_args.prefill_attention_backend

        if (
            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
            is not None
        ):
            if server_args.attention_backend is None:
                server_args.attention_backend = "dual_chunk_flash_attn"
                logger.info("Dual chunk attention is turned on by default.")
            elif server_args.attention_backend != "dual_chunk_flash_attn":
                raise ValueError(
                    "Dual chunk attention is enabled, but attention backend is set to "
                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
                )

        if server_args.attention_backend is None:
            """
            Auto-select the fastest attention backend.

            1. Models with MHA architecture (e.g., Llama, Qwen)
                1.1 On Hopper we turn on FA3, unless the user uses speculative decoding
                    with topk > 1 or page_size > 1.
                1.2 In other cases, we use flashinfer if available, otherwise triton.
            2. Models with MLA architecture
                2.1 We use the FA3 backend on Hopper.
                2.2 We use the flashinfer backend on Blackwell.
                2.3 Otherwise, we use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO: aiter currently only supports head numbers 16 and 128
                    if head_num == 128 or head_num == 16:
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization is not supported on CPU except for the intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

        if (
            server_args.enable_hierarchical_cache
            and server_args.hicache_io_backend == "kernel"
        ):
            # Fix the compatibility issue between FlashAttention3 decoding and the HiCache kernel backend
            if server_args.decode_attention_backend is None:
                if not self.use_mla_backend:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
                else:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_sm100_supported() else "triton"
                    )
            elif server_args.decode_attention_backend == "fa3":
                server_args.hicache_io_backend = "direct"
                logger.warning(
                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                )

        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
            if (
                quantization_config := getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
            ) is not None:
                text_config = self.model_config.hf_text_config
                weight_block_size_n = quantization_config["weight_block_size"][0]
                if (
                    text_config.moe_intermediate_size
                    // (self.tp_size // self.moe_ep_size)
                ) % weight_block_size_n != 0:
                    raise ValueError(
                        f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0. "
                        f"You can fix this by using arguments such as `--tp-size 8 --ep-size 8`."
                    )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set the local size to hint SGLang to use shared-memory-based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared-memory-based AllReduce are disabled because the intel_amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
                torch_compile=self.server_args.enable_piecewise_cuda_graph,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
            tp_rank=self.tp_rank,
            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
            if self.tp_rank == 0:
                instance_ip = socket.gethostbyname(socket.gethostname())
                t = threading.Thread(
                    target=trigger_init_weights_send_group_for_remote_instance_request,
                    args=(
                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
                        instance_ip,
                    ),
                )
                t.start()

        # Load the model
        # Remove these monkey patches once the quantization code in linear.py no longer depends on vLLM.
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region(
            GPU_MEMORY_TYPE_WEIGHTS,
            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
        ):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device, self.gpu_id),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        get_offloader().post_init()

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished model loading, but some other ranks did not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config, self.model_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_send_group_for_remote_instance(
        self,
        master_address,
        ports,
        group_rank,
        world_size,
        group_name,
        backend="nccl",
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        logger.info(
            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            self._weights_send_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{group_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=torch.device("cuda", self.gpu_id),
            )
            dist.barrier(group=self._weights_send_group[group_name])
            success = True
            message = (
                f"Succeeded to init group through {master_address}:{group_port} group."
            )
        except Exception as e:
            message = f"Failed to init group: {e}."
            logger.error(message)

        torch.cuda.empty_cache()
        return success, message

    def send_weights_to_remote_instance(
        self,
        master_address,
        ports,
        group_name,
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        if self._weights_send_group[group_name] is not None:
            send_group = self._weights_send_group[group_name]
        else:
            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
            logger.error(message)
            return False, message

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            for _, weights in self.model.named_parameters():
                torch.distributed.broadcast(
                    weights,
                    src=0,
                    group=send_group,
                )
            success = True
            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
        except Exception as e:
            message = f"Failed to send weights: {e}."
            logger.error(message)

        # destroy the process group after sending weights
        del self._weights_send_group[group_name]
        torch.distributed.distributed_c10d.destroy_process_group(send_group)
        torch.cuda.empty_cache()
        return success, message

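    # NOTE (illustrative summary, not an API contract): for remote-instance weight loading,
    # the seed instance first calls `init_weights_send_group_for_remote_instance` on every
    # TP rank to build a per-rank process group with the target instance, then calls
    # `send_weights_to_remote_instance`, which broadcasts every named parameter from group
    # rank 0 and destroys the group afterwards. The target instance is expected to join the
    # same groups (same master address and per-rank ports) and receive the broadcasts.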
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

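    # Illustrative RLHF flow (a sketch under assumptions, not an exact recipe): the training
    # engine creates the same custom group (it holds global rank 0), each inference TP rank
    # calls `init_weights_update_group` with a `rank_offset` so that its global rank becomes
    # `rank_offset + tp_rank`, and updated parameters are then pushed via
    # `update_weights_from_distributed` below, which receives the rank-0 broadcasts and feeds
    # them to `self.model.load_weights`.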
    def destroy_weights_update_group(self, group_name):
        try:
            if group_name in self._model_update_group:
                pg = self._model_update_group.pop(group_name)
                torch.distributed.destroy_process_group(pg)
                return True, "Succeeded to destroy custom process group."
            else:
                return False, "The group to be destroyed does not exist."
        except Exception as e:
            message = f"Failed to destroy custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific model parameters online through the
        `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
            group_name: the name of the process group created by
                `init_weights_update_group`.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        try:
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The weights of the ModelRunner are now only partially updated. "
                f"Please discard the whole set of weights."
            )
            logger.error(error_msg)
            return False, error_msg

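    # Hypothetical sender-side counterpart (for illustration only): for every parameter
    # received above, the training process (global rank 0 of the custom group) must issue a
    # matching broadcast in the same order, with the same dtype and shape, e.g.
    #
    #     torch.distributed.broadcast(param.data, src=0, group=update_group)
    #
    # so that the `torch.empty(...)` buffers allocated here receive the new values.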
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        monkey_patch_torch_reductions()
        if load_format == "flattened_bucket":
            # Handle the flattened bucket format
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get the device after the patch; otherwise the device would be wrong.
        self.device_module = torch.get_device_module(self.device)
        inferred_device = self.device_module.current_device()

        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=inferred_device))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

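    # Summary of the `load_format` values accepted by `update_weights_from_tensor` above:
    #   - "flattened_bucket": the payload is a dict handled by
    #     `_update_weights_from_flattened_bucket` below.
    #   - "direct": tensors are written via `_model_load_weights_direct`.
    #   - a name listed in `server_args.custom_weight_loader`: resolved with
    #     `dynamic_import` and called as `loader(model, named_tensors)`.
    #   - None: falls back to the model's own `load_weights`.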
    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle flattened bucket format for weight updates"""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert metadata dict to our format
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create bucket and reconstruct tensors
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"

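    # Illustrative sketch of the flattened-bucket layout handled above (numbers are
    # hypothetical): tensors are packed back-to-back into one 1-D buffer, and each
    # metadata entry records the [start_idx, end_idx) slice plus the original shape,
    # so reconstruction is essentially flattened_tensor[start:end].view(shape).
    #
    #   flattened_tensor: numel = 12
    #   metadata[0]: name="a.weight", shape=(2, 3), start_idx=0, end_idx=6,  numel=6
    #   metadata[1]: name="b.bias",   shape=(6,),   start_idx=6, end_idx=12, numel=6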
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; the performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
            lora_paths=self.server_args.lora_paths,
            server_args=self.server_args,
        )

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new LoRA adapter from disk or Hugging Face."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        elif config := self.mambaish_config:
            num_layers = len(config.full_attention_layer_ids)
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
            if is_deepseek_nsa(self.model_config.hf_config):
                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
                indexer_size_per_token = (
                    index_head_dim
                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
                )
                element_size = torch._utils._element_size(
                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
                )
                cell_size += indexer_size_per_token * num_layers * element_size
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        if self.mambaish_config is not None:
            rest_memory = self.handle_max_mamba_cache(rest_memory)
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

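    # Worked example for the MHA branch of profile_max_num_token above (hypothetical
    # numbers): with 8 KV heads per TP rank, head_dim=128, 32 layers, and an fp16 KV
    # cache (2 bytes per element),
    #   cell_size = 8 * 128 * 32 * 2 (K and V) * 2 bytes = 131072 bytes per token,
    # so rest_memory = 16 GiB yields int(16 * (1 << 30) // 131072) = 131072 tokens.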
    def handle_max_mamba_cache(self, total_rest_memory):
        config = self.mambaish_config
        server_args = self.server_args
        assert config is not None

        speculative_ratio = (
            0
            if server_args.speculative_num_draft_tokens is None
            else server_args.speculative_num_draft_tokens
        )
        if (
            server_args.disable_radix_cache
            or config.mamba2_cache_params.mamba_cache_per_req == 0
        ):
            # When the radix cache is disabled, set max_mamba_cache_size based on max_running_requests.
            if server_args.max_mamba_cache_size is None:
                if server_args.max_running_requests is not None:
                    server_args.max_mamba_cache_size = server_args.max_running_requests
                else:
                    server_args.max_mamba_cache_size = 512
        else:
            # Allocate the memory based on the ratio between mamba state memory and full kv cache memory,
            # solving the equations:
            # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
            # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
            mamba_state_memory_raw = (
                total_rest_memory
                * server_args.mamba_full_memory_ratio
                / (1 + server_args.mamba_full_memory_ratio)
            )
            # Calculate the max_mamba_cache_size based on the given total mamba memory.
            server_args.max_mamba_cache_size = int(
                (mamba_state_memory_raw * (1 << 30))
                // config.mamba2_cache_params.mamba_cache_per_req
                // (1 + speculative_ratio)
            )

        if self.hybrid_gdn_config is not None:
            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
                server_args.dp_size if server_args.enable_dp_attention else 1
            )
        mamba_state_memory = (
            server_args.max_mamba_cache_size
            * config.mamba2_cache_params.mamba_cache_per_req
            * (1 + speculative_ratio)
            / (1 << 30)
        )
        return total_rest_memory - mamba_state_memory

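    # Worked example for the ratio branch above (hypothetical numbers): with
    # total_rest_memory = 30 GiB and mamba_full_memory_ratio = 0.5, the two equations
    #   mamba_state_memory + full_kv_cache_memory == 30
    #   mamba_state_memory / full_kv_cache_memory == 0.5
    # give mamba_state_memory_raw = 30 * 0.5 / 1.5 = 10 GiB; max_mamba_cache_size is
    # then 10 GiB // mamba_cache_per_req // (1 + speculative_ratio), and the value
    # returned to the caller is the remaining budget for the full KV cache.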
    @property
    def hybrid_gdn_config(self):
        config = self.model_config.hf_config
        if isinstance(config, Qwen3NextConfig):
            return config
        return None

    @property
    def mamba2_config(self):
        config = self.model_config.hf_config
        if isinstance(config, FalconH1Config | NemotronHConfig):
            return config
        return None

    @property
    def mambaish_config(self):
        return self.mamba2_config or self.hybrid_gdn_config

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            try:
                layers = self.model.model.layers
            except AttributeError:
                try:
                    layers = self.model.language_model.model.layers
                except AttributeError:
                    try:
                        layers = self.model.language_model.layers
                    except AttributeError:
                        self.is_hybrid = False
                        return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # The existing max_total_num_tokens is per layer and assumes all layers have the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

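    # Worked example of the two equations solved in set_num_token_hybrid above
    # (hypothetical numbers): with 48 layers (36 SWA + 12 full), a uniform budget of
    # max_total_num_tokens per layer, and swa_full_tokens_ratio = 2,
    #   total_tokens = 48 * max_total_num_tokens
    #   denominator  = 2 * 36 + 12 = 84
    #   full_max_total_num_tokens = total_tokens / 84
    #   swa_max_total_num_tokens  = 2 * full_max_total_num_tokens
    # so each SWA layer gets twice the tokens of a full-attention layer while the
    # overall memory stays equal to the original uniform allocation.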
    def can_run_piecewise_cuda_graph(self):
        if self.server_args.disable_cuda_graph:
            log_info_on_rank0(
                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
            )
            return False
        if self.server_args.enable_torch_compile:
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
            )
            return False
        if self.pp_size > 1:
            # TODO(yuwei): support PP
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
            )
            return False
        return True

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        # Determine the kv cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            quant_config = getattr(self.model, "quant_config", None)
            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
            if (
                isinstance(kv_cache_quant_algo, str)
                and kv_cache_quant_algo.upper() == "FP8"
            ):
                if _is_hip:
                    self.kv_cache_dtype = torch.float8_e4m3fnuz
                else:
                    self.kv_cache_dtype = torch.float8_e4m3fn
            else:
                self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )
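        # Worked example of the heuristic above (hypothetical numbers): with
        # max_total_num_tokens = 500_000 and context_len = 8192,
        #   int(500_000 / 8192 * 512) = 31250, then max(31250, 2048) = 31250 and
        #   min(31250, 4096) = 4096,
        # so max_num_reqs defaults to 4096 unless explicitly passed by the caller.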

        if self.mambaish_config is not None:
            ratio = (
                MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
                if not self.server_args.disable_radix_cache
                else 1
            )
            max_num_reqs = min(
                max_num_reqs, self.server_args.max_mamba_cache_size // ratio
            )

        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        # Different pp ranks may have different numbers of layers, so reduce max_total_num_tokens to the minimum across ranks.
        if self.pp_size > 1:
            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=get_world_group().cpu_group,
            )
            self.max_total_num_tokens = tensor.item()

        # create token size for hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                f"Not enough memory. Please try to increase --mem-fraction-static. "
                f"Current value: {self.server_args.mem_fraction_static=}"
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
            extra_max_context_len = 4
            if self.server_args.speculative_num_draft_tokens is not None:
                extra_max_context_len += self.server_args.speculative_num_draft_tokens

            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # subscribe memory for pre-allocated requests
                # if max_num_reqs <= 32, we pre-allocate 2x requests
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            elif config := self.mambaish_config:
                self.req_to_token_pool = HybridReqToTokenPool(
                    size=max_num_reqs,
                    mamba_size=self.server_args.max_mamba_cache_size,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    cache_params=config.mamba2_cache_params,
                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    index_head_dim=self.model_config.index_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.model_config.num_hidden_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        elif self.use_mla_backend and is_nsa_model:
            self.token_to_kv_pool = NSATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
            )
        elif self.use_mla_backend:
            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            elif config := self.mambaish_config:
                self.token_to_kv_pool = HybridLinearKVPool(
                    page_size=self.page_size,
                    size=self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    # if draft worker, we only need 1 attention layer's kv pool
                    full_attention_layer_ids=(
                        [0] if self.is_draft_worker else config.full_attention_layer_ids
                    ),
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                    enable_kv_cache_copy=(
                        self.server_args.speculative_algorithm is not None
                    ),
                )

        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
            if _is_npu and (
                self.server_args.attention_backend == "ascend"
                or self.hybrid_gdn_config is not None
            ):
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                )
            else:
                if self.page_size == 1:
                    if self.is_hybrid:
                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                            self.full_max_total_num_tokens,
                            self.swa_max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                    else:
                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                            self.max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                else:
                    assert not self.is_hybrid
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self):
        """Init attention kernel backend."""
        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
            self.server_args.get_attention_backends()
        )

        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                self,
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "Warning: Attention backend specified by --attention-backend or default backend might be overridden. "
                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problems."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend
            )

        (
            get_global_server_args().prefill_attention_backend,
            get_global_server_args().decode_attention_backend,
        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
        return attn_backend

    def _get_attention_backend_from_str(self, backend_str: str):
        if backend_str not in ATTENTION_BACKENDS:
            raise ValueError(f"Invalid attention backend: {backend_str}")
        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
        return attn_backend_wrapper(self, full_attention_backend)

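    # Illustrative sketch of the registry lookup in _get_attention_backend_from_str
    # above ("flashinfer" is just an example key): ATTENTION_BACKENDS maps a backend
    # name to a constructor that takes the model runner, and attn_backend_wrapper may
    # wrap the result (e.g. for hybrid or draft-model setups), roughly:
    #
    #   >>> backend = attn_backend_wrapper(runner, ATTENTION_BACKENDS["flashinfer"](runner))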
    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_device_graphs(self):
        """Capture device graphs."""
        self.graph_runner = None
        self.graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.device != "cpu" and self.server_args.disable_cuda_graph:
            return

        if self.device == "cpu" and not self.server_args.enable_torch_compile:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        graph_runners = defaultdict(
            lambda: CudaGraphRunner,
            {
                "cpu": CPUGraphRunner,
                "npu": NPUGraphRunner,
            },
        )
        self.graph_runner = graph_runners[self.device](self)

        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)
        if omp_cpuids == "all":
            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            threads_bind_list = omp_cpuids.split("|")
            assert self.tp_size == len(threads_bind_list), (
                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
                f"Please double check your settings."
            )
            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
            if self.tp_size > n_numa_node:
                logger.warning(
                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
                    f"in this case the available memory amount of each rank cannot be determined in prior. "
                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
                )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.layers.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True

        if self.piecewise_cuda_graph_runner is not None:
            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)

        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

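    # Illustrative sketch of how forward_split_prefill above is typically driven
    # (hypothetical loop; the real call sites live in the scheduler): layers are
    # processed in chunks of `forward_count` until split_index reaches
    # num_hidden_layers, and only the final chunk is expected to produce the logits.
    #
    #   >>> forward_batch.split_index = 0
    #   >>> ret = None
    #   >>> while forward_batch.split_index < runner.model_config.num_hidden_layers:
    #   ...     ret = runner.forward_split_prefill(forward_batch, forward_count=4)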
    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        mode_check = (
            forward_batch.forward_mode.is_cpu_graph
            if self.device == "cpu"
            else forward_batch.forward_mode.is_cuda_graph
        )
        can_run_graph = bool(
            mode_check()
            and self.graph_runner
            and self.graph_runner.can_run(forward_batch)
        )

        if can_run_graph:
            ret = self.graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if (
            forward_batch.global_num_tokens_cpu is not None
            and self.pp_group.is_last_rank
        ):
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
        #       was executed after we processed last batch's results.

        # Calculate logits bias and apply it to next_token_logits.
        sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)
        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
            # For prefill, we only use the position of the last token.
            (
                forward_batch.positions
                if forward_batch.forward_mode.is_decode()
                else forward_batch.seq_lens - 1
            ),
        )
        return next_token_ids

    def compute_logprobs_only(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> None:
        """
        Compute token_ids_logprobs without performing sampling.

        Optimized path for prefill-only requests that need token_ids_logprobs but don't
        require next token generation. Skips expensive sampling operations
        while still providing requested probability information.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output
        """
        if not forward_batch.token_ids_logprobs:
            return

        # Preprocess logits (same as in sample method)
        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Delegate to sampler for logprob-only computation
        # This populates logits_output with requested token probabilities
        self.sampler.compute_logprobs_only(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

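    # Illustrative example of the config field inspected by model_is_mrope above
    # (values are hypothetical): an mrope-enabled HF config carries a "mrope_section"
    # entry inside rope_scaling, e.g.
    #
    #   rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
    #
    # which makes the property return True; a missing or None rope_scaling returns False.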
    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
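
# Illustrative sketch of how LocalSerializedTensor is consumed by _unwrap_tensor above
# (hypothetical setup): the sender serializes one handle per TP rank, and each rank
# deserializes only its own entry before moving the tensor to the local device.
#
#   >>> packed = LocalSerializedTensor(values=[bytes_for_rank0, bytes_for_rank1])
#   >>> tensor = packed.get(tp_rank)   # MultiprocessingSerializer.deserialize(values[tp_rank])
#   >>> tensor = tensor.to(device)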