# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import socket
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs import (
    FalconH1Config,
    KimiLinearConfig,
    NemotronHConfig,
    Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
    AttentionArch,
    ModelConfig,
    get_nsa_index_head_dim,
    is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.debug_utils.tensor_dump_forward_hook import (
    register_forward_hook_for_model,
)
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    attn_backend_wrapper,
)
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    HybridLinearKVPool,
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
    PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_hip,
    is_npu,
    log_info_on_rank0,
    monkey_patch_p2p_access_check,
    set_cuda_arch,
    slow_rank_detector,
    xpu_has_xmx_support,
)
from sglang.srt.utils.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

MLA_ATTENTION_BACKENDS = [
    "aiter",
    "flashinfer",
    "fa3",
    "fa4",
    "triton",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
]

CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
    "flashinfer",
    "fa3",
    "fa4",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
]


def add_mla_attention_backend(backend_name):
    if backend_name not in MLA_ATTENTION_BACKENDS:
        MLA_ATTENTION_BACKENDS.append(backend_name)
        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


def add_chunked_prefix_cache_attention_backend(backend_name):
    if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
        CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
        logger.info(
            f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
        )


_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu_xmx_available = xpu_has_xmx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

# The ratio of the mamba cache pool size to max_running_requests; values larger than 2 are safe (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

logger = logging.getLogger(__name__)

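# NPU-specific global settings: allow internal tensor formats and disable JIT compilation in torch_npu.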
if _is_npu:
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
        self.forward_pass_id = 0
        self.init_new_workspace = False

        # Apply the rank zero filter to logger
        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        # Set the global server_args in the scheduler process
        set_global_server_args_for_scheduler(server_args)
        global_server_args = get_global_server_args()

        # FIXME: hacky set `use_mla_backend`
        global_server_args.use_mla_backend = self.use_mla_backend

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # CPU offload
        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
            slow_rank_detector.execute()

        # Update the DeepGEMM config
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # Temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}
        self._weights_send_group = {}

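        # Piecewise CUDA graph needs a handle to every layer's attention module.
        # If any layer does not expose a standard `self_attn.attn` (e.g. non-standard GQA),
        # fall back to running without the piecewise CUDA graph runner.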
        if (
            self.server_args.enable_piecewise_cuda_graph
            and self.can_run_piecewise_cuda_graph()
        ):
            self.attention_layers = []
            for layer in self.model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
                    self.attention_layers.append(layer.self_attn.attn)
            if len(self.attention_layers) < self.model_config.num_hidden_layers:
                # TODO(yuwei): support Non-Standard GQA
                log_info_on_rank0(
                    logger,
                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
                )
                self.piecewise_cuda_graph_runner = None
            else:
                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
        else:
            self.piecewise_cuda_graph_runner = None

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(
                    server_args=server_args,
                    model_config=self.model_config,
                    moe_ep_rank=self.moe_ep_rank,
                )
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        (
            ElasticEPStateManager.init(self.server_args)
            if self.server_args.elastic_ep_backend
            else None
        )
        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        if config := self.mamba2_config:
            class_name = config.__class__.__name__
            logger.warning(f"{class_name} model detected, disabling radix cache")
            self.server_args.disable_radix_cache = True

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (
            (not model_has_mtp_layers)
            or (self.spec_algorithm.is_none())
            or (
                (not self.spec_algorithm.is_none())
                and (self.num_effective_layers == model_num_layers)
            )
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, get_global_server_args().torchao_config
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init Double Sparsity
        if server_args.enable_double_sparsity:
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        # Enable batch invariant mode
        if server_args.enable_deterministic_inference:
            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

            enable_batch_invariant_mode()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_device_graphs()
        elif self.device in ["npu", "cpu"]:
            self.init_attention_backend()
            self.init_device_graphs()
        else:
            self.graph_runner = None
            self.graph_mem_usage = 0
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # load draft config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # get the aux layer from draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except:
                # if there is no aux layer, set to None
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if (
            not self.use_mla_backend
            or server_args.attention_backend
            not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
        ):
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
            if (
                quantization_config := getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
            ) is not None and "weight_block_size" in quantization_config:
                weight_block_size_n = quantization_config["weight_block_size"][0]

                if self.tp_size % self.moe_ep_size != 0:
                    raise ValueError(
                        f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
                    )
                moe_tp_size = self.tp_size // self.moe_ep_size

                moe_intermediate_size = (
                    self.model_config.hf_text_config.moe_intermediate_size
                )
                if moe_intermediate_size % moe_tp_size != 0:
                    raise ValueError(
                        f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
                    )

                if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
                    raise ValueError(
                        f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
                        f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
                        f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
                    )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

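        # Select the torch.distributed backend for the device. On CUDA this is NCCL,
        # unless the Mooncake elastic-EP backend is enabled, which replaces it below.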
        if self.device == "cuda":
            if self.server_args.elastic_ep_backend == "mooncake":
                backend = "mooncake"
                if self.server_args.mooncake_ib_device:
                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                    try:
                        from mooncake import ep as mooncake_ep

                        mooncake_ep.set_device_filter(mooncake_ib_device)
                    except:
                        pass  # A warning will be raised in `init_distributed_environment`
            else:
                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce are disabled since the Intel AMX backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
                torch_compile=self.server_args.enable_piecewise_cuda_graph,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        from sglang.srt.configs.modelopt_config import ModelOptConfig

        modelopt_config = ModelOptConfig(
            quant=self.server_args.modelopt_quant,
            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
            export_path=self.server_args.modelopt_export_path,
            quantize_and_serve=self.server_args.quantize_and_serve,
        )

        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
            tp_rank=self.tp_rank,
            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
            modelopt_config=modelopt_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )

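        # When loading weights from a remote instance, rank 0 asks the seed instance
        # (in a background thread) to set up the process groups used to send the weights.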
        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
            if self.tp_rank == 0:
                instance_ip = socket.gethostbyname(socket.gethostname())
                t = threading.Thread(
                    target=trigger_init_weights_send_group_for_remote_instance_request,
                    args=(
                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
                        instance_ip,
                    ),
                )
                t.start()

        # Load the model
        # Remove this monkey_patch when linear.py quant removes its dependencies on vllm
        monkey_patch_vllm_parallel_state()

        with self.memory_saver_adapter.region(
            GPU_MEMORY_TYPE_WEIGHTS,
            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
        ):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device, self.gpu_id),
            )
        monkey_patch_vllm_parallel_state(reverse=True)

        get_offloader().post_init()

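        # When serving with an FP8 KV cache, optionally load per-layer KV scaling
        # factors from `server_args.quantization_param_path`; otherwise scales default to 1.0.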
        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        "model %s does not support loading scaling factors.",
                        self.model.__class__,
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )
        if self.server_args.debug_tensor_dump_output_folder is not None:
            register_forward_hook_for_model(
                self.model,
                self.server_args.debug_tensor_dump_output_folder,
                self.server_args.debug_tensor_dump_layers,
                self.tp_size,
                self.tp_rank,
                self.pp_rank,
            )

        if self.server_args.elastic_ep_backend == "mooncake":
            # Mooncake does not support `monitored_barrier`
            dist.barrier(group=get_tp_group().cpu_group)
        else:
            # Handle the case where some ranks do not finish loading.
            try:
                dist.monitored_barrier(
                    group=get_tp_group().cpu_group,
                    timeout=datetime.timedelta(
                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
                    ),
                    wait_all_ranks=True,
                )
            except RuntimeError:
                raise ValueError(
                    f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
                ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        if ElasticEPStateManager.instance() is not None:
            # TODO: refactor the weights update when elastic ep
            old_expert_location_metadata = get_global_expert_location_metadata()
            assert old_expert_location_metadata is not None
            old_expert_location_metadata.update(
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
            )
            self.update_weights_from_disk(
                self.server_args.model_path,
                self.server_args.load_format,
                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
            )
        else:
            self.expert_location_updater.update(
                self.model.routed_experts_weights_of_layer,
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
                nnodes=self.server_args.nnodes,
                rank=self.tp_rank,
            )

    def update_weights_from_disk(
        self,
        model_path: str,
        load_format: str,
        weight_name_filter: Optional[Callable[[str], bool]] = None,
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config, self.model_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

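        # Build a weight iterator over the checkpoint, optionally filtered by name,
        # so that only a subset of parameters (e.g. MoE experts) is reloaded.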
        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            if weight_name_filter is not None:
                iter = (
                    (name, weight) for name, weight in iter if weight_name_filter(name)
                )

            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_send_group_for_remote_instance(
        self,
        master_address,
        ports,
        group_rank,
        world_size,
        group_name,
        backend="nccl",
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        logger.info(
            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            self._weights_send_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{group_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=torch.device("cuda", self.gpu_id),
            )
            dist.barrier(group=self._weights_send_group[group_name])
            success = True
            message = (
                f"Succeeded to init group through {master_address}:{group_port} group."
            )
        except Exception as e:
            message = f"Failed to init group: {e}."
            logger.error(message)

        torch.cuda.empty_cache()
        return success, message

    def send_weights_to_remote_instance(
        self,
        master_address,
        ports,
        group_name,
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        if self._weights_send_group.get(group_name) is not None:
            send_group = self._weights_send_group[group_name]
        else:
            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
            logger.error(message)
            return False, message

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            for _, weights in self.model.named_parameters():
                torch.distributed.broadcast(
                    weights,
                    src=0,
                    group=send_group,
                )
            success = True
            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
        except Exception as e:
            message = f"Failed to send weights: {e}."
            logger.error(message)

        # destroy the process group after sending weights
        del self._weights_send_group[group_name]
        torch.distributed.distributed_c10d.destroy_process_group(send_group)
        torch.cuda.empty_cache()
        return success, message

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def destroy_weights_update_group(self, group_name):
        try:
            if group_name in self._model_update_group:
                pg = self._model_update_group.pop(group_name)
                torch.distributed.destroy_process_group(pg)
                return True, "Succeeded to destroy custom process group."
            else:
                return False, "The group to be destroyed does not exist."
        except Exception as e:
            message = f"Failed to destroy custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
            group_name: the name of the process group used for the broadcast.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

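        # Receive every parameter via an async broadcast from rank 0 (the training
        # engine), wait for all transfers, then load the weights in a single pass.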
        try:
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        monkey_patch_torch_reductions()
        if load_format == "flattened_bucket":
            # Handle flattened bucket format
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get the device after the patch; otherwise the device would be wrong
        self.device_module = torch.get_device_module(self.device)
        infered_device = self.device_module.current_device()

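        # Unwrap tensors (which may be LocalSerializedTensor handles, one shard per
        # TP rank) onto the local device before loading.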
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle flattened bucket format for weight updates"""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert metadata dict to our format
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create bucket and reconstruct tensors
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
1210
            lora_backend=self.server_args.lora_backend,
1211
1212
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
1213
1214
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
1215
            lora_paths=self.server_args.lora_paths,
1216
            server_args=self.server_args,
1217
        )
1218

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new lora adapter from disk or huggingface."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        elif config := self.mambaish_config:
            num_layers = len(config.full_attention_layer_ids)
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
            if is_deepseek_nsa(self.model_config.hf_config):
                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
                indexer_size_per_token = (
                    index_head_dim
                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
                )
                element_size = torch._utils._element_size(
                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
                )
                cell_size += indexer_size_per_token * num_layers * element_size
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        if self.mambaish_config is not None:
            rest_memory = self.handle_max_mamba_cache(rest_memory)
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token
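
    # Worked example (illustrative numbers, not tied to a specific model): in the
    # non-MLA branch each token costs
    #   cell_size = num_kv_heads * head_dim * num_layers * 2 (K and V) * dtype_size
    # e.g. 8 * 128 * 32 * 2 * 2 bytes = 131072 bytes (128 KiB) per token, so roughly
    # 8192 tokens fit per free GiB counted in `rest_memory`.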

    def handle_max_mamba_cache(self, total_rest_memory):
        config = self.mambaish_config
        server_args = self.server_args
        assert config is not None

        speculative_ratio = (
            0
            if server_args.speculative_num_draft_tokens is None
            else server_args.speculative_num_draft_tokens
        )
        if (
            server_args.disable_radix_cache
            or config.mamba2_cache_params.mamba_cache_per_req == 0
        ):
            # With the radix cache disabled, set max_mamba_cache_size based on max_running_requests.
            if server_args.max_mamba_cache_size is None:
                if server_args.max_running_requests is not None:
                    server_args.max_mamba_cache_size = server_args.max_running_requests
                else:
                    server_args.max_mamba_cache_size = 512
        else:
            # Allocate the memory based on the ratio between mamba state memory and full KV cache memory.
            # Solve the equations:
            # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
            # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
            mamba_state_memory_raw = (
                total_rest_memory
                * server_args.mamba_full_memory_ratio
                / (1 + server_args.mamba_full_memory_ratio)
            )
            # Calculate max_mamba_cache_size based on the resulting total mamba memory.
            server_args.max_mamba_cache_size = int(
                (mamba_state_memory_raw * (1 << 30))
                // config.mamba2_cache_params.mamba_cache_per_req
                // (1 + speculative_ratio)
            )

        if self.hybrid_gdn_config is not None:
            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
                server_args.dp_size if server_args.enable_dp_attention else 1
            )
        mamba_state_memory = (
            server_args.max_mamba_cache_size
            * config.mamba2_cache_params.mamba_cache_per_req
            * (1 + speculative_ratio)
            / (1 << 30)
        )
        return total_rest_memory - mamba_state_memory
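
    # Worked example (illustrative numbers): with total_rest_memory = 30 GiB and
    # mamba_full_memory_ratio = 0.2, the mamba share is 30 * 0.2 / 1.2 = 5 GiB and
    # ~25 GiB is returned for the full-attention KV cache; max_mamba_cache_size is
    # then 5 GiB // mamba_cache_per_req, further divided by (1 + speculative_ratio).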

    @property
    def hybrid_gdn_config(self):
        config = self.model_config.hf_config
        if isinstance(config, Qwen3NextConfig):
            return config
        return None

    @property
    def mamba2_config(self):
        config = self.model_config.hf_config
        if isinstance(config, FalconH1Config | NemotronHConfig):
            return config
        return None

    @property
    def kimi_linear_config(self):
        config = self.model_config.hf_config
        if isinstance(config, KimiLinearConfig):
            return config
        return None

    @property
    def mambaish_config(self):
        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            try:
                layers = self.model.model.layers
            except:
                try:
                    layers = self.model.language_model.model.layers
                except:
                    try:
                        layers = self.model.language_model.layers
                    except:
                        self.is_hybrid = False
                        return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )
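
            # Worked example (illustrative numbers): with 10 SWA layers, 2 full-attention
            # layers, swa_full_tokens_ratio = 0.5 and total_tokens = 1,200,000:
            #   full_max_total_num_tokens = int(1,200,000 / (0.5 * 10 + 2)) = 171,428
            #   swa_max_total_num_tokens  = int(171,428 * 0.5) = 85,714
            # so 85,714 * 10 + 171,428 * 2 ~= 1,200,000, satisfying equation 1 above.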

    def can_run_piecewise_cuda_graph(self):
        if self.server_args.disable_cuda_graph:
            log_info_on_rank0(
                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
            )
            return False
        if self.server_args.enable_torch_compile:
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
            )
            return False
        if self.pp_size > 1:
            # TODO(yuwei): support PP
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
            )
            return False
        return True

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        # Determine the kv cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            quant_config = getattr(self.model, "quant_config", None)
            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
            if (
                isinstance(kv_cache_quant_algo, str)
                and kv_cache_quant_algo.upper() == "FP8"
            ):
                if _is_hip:
                    self.kv_cache_dtype = torch.float8_e4m3fnuz
                else:
                    self.kv_cache_dtype = torch.float8_e4m3fn
            else:
                self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
            self.kv_cache_dtype = torch.bfloat16
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if self.mambaish_config is not None:
            ratio = (
                MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
                if not self.server_args.disable_radix_cache
                else 1
            )
            max_num_reqs = min(
                max_num_reqs, self.server_args.max_mamba_cache_size // ratio
            )

        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and draft worker share the same indices for the
                # token_to_kv_pool, so make sure max_total_num_tokens matches.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs
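
                # Worked example (illustrative numbers): with max_num_reqs = 32,
                # speculative_num_steps = 3, speculative_eagle_topk = 4 and
                # speculative_num_draft_tokens = 8, the shared pool grows by
                #   32 * 3 * 4 (draft) + 32 * 8 (verify) + 100 (buffer) = 740 tokens
                # on top of the profiled max_total_num_tokens.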

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Using the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        # Different PP ranks may have different numbers of layers, so take the
        # minimum max_total_num_tokens across all ranks.
        if self.pp_size > 1:
            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=get_world_group().cpu_group,
            )
            self.max_total_num_tokens = tensor.item()

        # create token size for hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                f"Not enough memory. Please try to increase --mem-fraction-static. "
                f"Current value: {self.server_args.mem_fraction_static=}"
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
            extra_max_context_len = 4
            if self.server_args.speculative_num_draft_tokens is not None:
                extra_max_context_len += self.server_args.speculative_num_draft_tokens

            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import (
                    DecodeReqToTokenPool,
                    HybridMambaDecodeReqToTokenPool,
                )

                # Reserve memory for pre-allocated requests:
                # if max_num_reqs <= 32, we pre-allocate 2x requests.
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                if config := self.mambaish_config:
                    self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
                        size=max_num_reqs,
                        max_context_len=self.model_config.context_len
                        + extra_max_context_len,
                        device=self.device,
                        enable_memory_saver=self.server_args.enable_memory_saver,
                        cache_params=config.mamba2_cache_params,
                        speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                        pre_alloc_size=pre_alloc_size,
                    )
                else:
                    self.req_to_token_pool = DecodeReqToTokenPool(
                        size=max_num_reqs,
                        max_context_len=self.model_config.context_len
                        + extra_max_context_len,
                        device=self.device,
                        enable_memory_saver=self.server_args.enable_memory_saver,
                        pre_alloc_size=pre_alloc_size,
                    )
            elif config := self.mambaish_config:
                self.req_to_token_pool = HybridReqToTokenPool(
                    size=max_num_reqs,
                    mamba_size=self.server_args.max_mamba_cache_size,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    cache_params=config.mamba2_cache_params,
                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    index_head_dim=self.model_config.index_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
        elif self.use_mla_backend and is_nsa_model:
            self.token_to_kv_pool = NSATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
            )
        elif self.use_mla_backend and not self.mambaish_config:
            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            elif config := self.mambaish_config:
                extra_args = {}
                if self.use_mla_backend:
                    extra_args = {
                        "kv_lora_rank": self.model_config.kv_lora_rank,
                        "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
                    }
                self.token_to_kv_pool = HybridLinearKVPool(
                    page_size=self.page_size,
                    size=self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    # if draft worker, we only need 1 attention layer's kv pool
                    full_attention_layer_ids=(
                        [0] if self.is_draft_worker else config.full_attention_layer_ids
                    ),
                    enable_kvcache_transpose=False,
                    device=self.device,
                    mamba_pool=self.req_to_token_pool.mamba_pool,
                    use_mla=self.use_mla_backend,
                    **extra_args,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                    enable_alt_stream=not self.server_args.enable_pdmux,
                    enable_kv_cache_copy=(
                        self.server_args.speculative_algorithm is not None
                    ),
                )

        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
            if _is_npu and (
                self.server_args.attention_backend == "ascend"
                or self.hybrid_gdn_config is not None
            ):
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                )
            else:
                if self.page_size == 1:
                    if self.is_hybrid:
                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                            self.full_max_total_num_tokens,
                            self.swa_max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                    else:
                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                            self.max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                else:
                    assert not self.is_hybrid
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_pdmux:
            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
            self.decode_attn_backend_group = []
            for _ in range(self.server_args.sm_group_num):
                self.decode_attn_backend_group.append(self._get_attention_backend())
            self.decode_attn_backend = self.decode_attn_backend_group[0]
        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self, init_new_workspace: bool = False):
        """Init attention kernel backend."""
        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
            self.server_args.get_attention_backends()
        )

        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                self,
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str,
                    init_new_workspace=init_new_workspace,
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str,
                    init_new_workspace=init_new_workspace,
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "Warning: The attention backend specified by --attention-backend (or the default backend) may be overridden. "
                "The hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problems."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend,
                init_new_workspace=init_new_workspace,
            )

        (
            get_global_server_args().prefill_attention_backend,
            get_global_server_args().decode_attention_backend,
        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
        return attn_backend

    def _get_attention_backend_from_str(
        self, backend_str: str, init_new_workspace: bool = False
    ):
        if backend_str not in ATTENTION_BACKENDS:
            raise ValueError(f"Invalid attention backend: {backend_str}")
        self.init_new_workspace = init_new_workspace
        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
        return attn_backend_wrapper(self, full_attention_backend)

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_device_graphs(self):
        """Capture device graphs."""
        self.graph_runner = None
        self.graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.device != "cpu" and self.server_args.disable_cuda_graph:
            return

        if self.device == "cpu" and not self.server_args.enable_torch_compile:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        graph_runners = defaultdict(
            lambda: CudaGraphRunner,
            {
                "cpu": CPUGraphRunner,
                "npu": NPUGraphRunner,
            },
        )
        self.graph_runner = graph_runners[self.device](self)

        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)
        if omp_cpuids == "all":
            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            threads_bind_list = omp_cpuids.split("|")
            assert self.tp_size == len(threads_bind_list), (
                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
                f"Please double check your settings."
            )
            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
            if self.tp_size > n_numa_node:
                logger.warning(
                    f"TP size ({self.tp_size}) is larger than the number of numa nodes ({n_numa_node}), "
                    f"so the available memory for each rank cannot be determined in advance. "
                    f"Please set a proper `--max-total-tokens` to avoid out-of-memory errors."
                )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.layers.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def update_decode_attn_backend(self, stream_idx: int):
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            if self.server_args.enable_pdmux:
                self.decode_attn_backend.init_forward_metadata(forward_batch)
                forward_batch.attn_backend = self.decode_attn_backend
            else:
                self.attn_backend.init_forward_metadata(forward_batch)

        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True

        if self.piecewise_cuda_graph_runner is not None:
            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)

        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        mode_check = (
            forward_batch.forward_mode.is_cpu_graph
            if self.device == "cpu"
            else forward_batch.forward_mode.is_cuda_graph
        )
        can_run_graph = bool(
            mode_check()
            and self.graph_runner
            and self.graph_runner.can_run(forward_batch)
        )

        if can_run_graph:
            ret = self.graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if (
            forward_batch.global_num_tokens_cpu is not None
            and self.pp_group.is_last_rank
        ):
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
        #       was executed after we processed last batch's results.

        # Calculate logits bias and apply it to next_token_logits.
        sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
            # For prefill, we only use the position of the last token.
            (
                forward_batch.positions
                if forward_batch.forward_mode.is_decode()
                else forward_batch.seq_lens - 1
            ),
        )
        return next_token_ids

    def compute_logprobs_only(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> None:
        """
        Compute token_ids_logprobs without performing sampling.

        Optimized path for prefill-only requests that need token_ids_logprobs but don't
        require next token generation. Skips expensive sampling operations
        while still providing requested probability information.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output
        """
        if not forward_batch.token_ids_logprobs:
            return

        # Preprocess logits (same as in sample method)
        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Delegate to sampler for logprob-only computation
        # This populates logits_output with requested token probabilities
        self.sampler.compute_logprobs_only(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has the "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled
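
    # Illustrative example (values are assumptions, not taken from a specific
    # checkpoint): a config with
    #   rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
    # makes this property True, while rope_scaling=None or a dict without
    # "mrope_section" makes it False.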

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

    def update_weights_from_ipc(self, recv_req):
        """Update weights from IPC for checkpoint-engine integration."""
        try:
            from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
                SGLangCheckpointEngineWorkerExtensionImpl,
            )

            # Create a worker extension that integrates with SGLang's model
            worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
            worker.update_weights_from_ipc(recv_req.zmq_handles)
            return True, "IPC weight update completed successfully"
        except ImportError as e:
            return False, f"IPC weight update failed: ImportError {e}"
        except Exception as e:
            logger.error(f"IPC weight update failed: {e}")
            return False, str(e)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
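
# Usage sketch (assumed flow, for illustration only): each rank serializes its own
# shard with MultiprocessingSerializer.serialize(...), the sender gathers the bytes
# into LocalSerializedTensor(values=[rank0_bytes, rank1_bytes, ...]), and the
# receiver on tp_rank `r` calls _unwrap_tensor(obj, tp_rank=r, device="cuda") to
# deserialize only its own entry and move it onto the local device.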