"docs/vscode:/vscode.git/clone" did not exist on "6057d3cf1c2f3a4c5072a3853a021bb8b4ce61f7"
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import socket
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
    AttentionArch,
    ModelConfig,
    get_nsa_index_head_dim,
    is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.debug_utils.tensor_dump_forward_hook import (
    register_forward_hook_for_model,
)
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    attn_backend_wrapper,
)
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    HybridLinearKVPool,
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
    PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_hip,
    is_npu,
    log_info_on_rank0,
    monkey_patch_p2p_access_check,
    set_cuda_arch,
    slow_rank_detector,
    xpu_has_xmx_support,
)
from sglang.srt.utils.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

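# Attention backends that can serve MLA (Multi-head Latent Attention) models.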
MLA_ATTENTION_BACKENDS = [
    "aiter",
    "flashinfer",
    "fa3",
    "fa4",
    "triton",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
]

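# Attention backends that support chunked prefix caching (only used together with the MLA backend).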
CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
    "flashinfer",
    "fa3",
    "fa4",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
]


def add_mla_attention_backend(backend_name):
    if backend_name not in MLA_ATTENTION_BACKENDS:
        MLA_ATTENTION_BACKENDS.append(backend_name)
        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


def add_chunked_prefix_cache_attention_backend(backend_name):
    if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
        CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
        logger.info(
            f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
        )


_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu_xmx_available = xpu_has_xmx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

# The ratio of the mamba cache pool size to max_running_requests; it is safe when this is larger than 2 (yizhang2077).
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

logger = logging.getLogger(__name__)

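# NPU-specific global settings: allow NPU-internal tensor formats and disable JIT compilation.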
if _is_npu:
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
        self.forward_pass_id = 0
        self.init_new_workspace = False

        # Apply the rank zero filter to logger
        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        # Set the global server_args in the scheduler process
        set_global_server_args_for_scheduler(server_args)
        global_server_args = get_global_server_args()

        # FIXME: hacky set `use_mla_backend`
        global_server_args.use_mla_backend = self.use_mla_backend

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # CPU offload
        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
            slow_rank_detector.execute()

        # Update deep gemm configure
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # Temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}
        self._weights_send_group = {}

        if (
            self.server_args.enable_piecewise_cuda_graph
            and self.can_run_piecewise_cuda_graph()
        ):
            self.attention_layers = []
            for layer in self.model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
                    self.attention_layers.append(layer.self_attn.attn)
            if len(self.attention_layers) < self.model_config.num_hidden_layers:
                # TODO(yuwei): support Non-Standard GQA
                log_info_on_rank0(
                    logger,
                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
                )
                self.piecewise_cuda_graph_runner = None
            else:
                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
        else:
            self.piecewise_cuda_graph_runner = None

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(
                    server_args=server_args,
                    model_config=self.model_config,
                    moe_ep_rank=self.moe_ep_rank,
                )
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        (
            ElasticEPStateManager.init(self.server_args)
            if self.server_args.elastic_ep_backend
            else None
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        if config := self.mamba2_config:
            class_name = config.__class__.__name__
            logger.warning(f"{class_name} model detected, disable radix cache")
            self.server_args.disable_radix_cache = True

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (
            (not model_has_mtp_layers)
            or (self.spec_algorithm.is_none())
            or (
                (not self.spec_algorithm.is_none())
                and (self.num_effective_layers == model_num_layers)
            )
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, get_global_server_args().torchao_config
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init Double Sparsity
        if server_args.enable_double_sparsity:
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        # Enable batch invariant mode
        if server_args.enable_deterministic_inference:
            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

            enable_batch_invariant_mode()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_device_graphs()
        elif self.device in ["npu", "cpu"]:
            self.init_attention_backend()
            self.init_device_graphs()
        else:
            self.graph_runner = None
            self.graph_mem_usage = 0
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # load draft config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # get the aux layer from draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except Exception:
                # if there is no aux layer, set to None
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if (
            not self.use_mla_backend
            or server_args.attention_backend
            not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
        ):
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
            if (
                quantization_config := getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
            ) is not None and "weight_block_size" in quantization_config:
                weight_block_size_n = quantization_config["weight_block_size"][0]

                if self.tp_size % self.moe_ep_size != 0:
                    raise ValueError(
                        f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
                    )
                moe_tp_size = self.tp_size // self.moe_ep_size

                moe_intermediate_size = (
                    self.model_config.hf_text_config.moe_intermediate_size
                )
                if moe_intermediate_size % moe_tp_size != 0:
                    raise ValueError(
                        f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
                    )

                if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
                    raise ValueError(
                        f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
                        f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
                        f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
                    )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

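        # Pick the torch.distributed backend for this device type:
        # NCCL (or Mooncake for elastic EP) on CUDA, XCCL on XPU, HCCL on HPU/NPU, Gloo on CPU.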
        if self.device == "cuda":
            if self.server_args.elastic_ep_backend == "mooncake":
                backend = "mooncake"
                if self.server_args.mooncake_ib_device:
                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                    try:
                        from mooncake import ep as mooncake_ep

                        mooncake_ep.set_device_filter(mooncake_ib_device)
                    except Exception:
                        pass  # A warning will be raised in `init_distributed_environment`
            else:
                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce are disabled since the Intel AMX backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
                torch_compile=self.server_args.enable_piecewise_cuda_graph,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        from sglang.srt.configs.modelopt_config import ModelOptConfig

        modelopt_config = ModelOptConfig(
            quant=self.server_args.modelopt_quant,
            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
            export_path=self.server_args.modelopt_export_path,
            quantize_and_serve=self.server_args.quantize_and_serve,
        )

        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
            tp_rank=self.tp_rank,
            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
            modelopt_config=modelopt_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )

        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
            if self.tp_rank == 0:
                instance_ip = socket.gethostbyname(socket.gethostname())
                t = threading.Thread(
                    target=trigger_init_weights_send_group_for_remote_instance_request,
                    args=(
                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
                        instance_ip,
                    ),
                )
                t.start()

        # Load the model
        # Remove monkey_patch when linear.py quant removes dependencies on vllm
        monkey_patch_vllm_parallel_state()

        with self.memory_saver_adapter.region(
            GPU_MEMORY_TYPE_WEIGHTS,
            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
        ):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device, self.gpu_id),
            )
        monkey_patch_vllm_parallel_state(reverse=True)

        get_offloader().post_init()

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        if self.server_args.debug_tensor_dump_output_folder is not None:
            register_forward_hook_for_model(
                self.model,
                self.server_args.debug_tensor_dump_output_folder,
                self.server_args.debug_tensor_dump_layers,
                self.tp_size,
                self.tp_rank,
                self.pp_rank,
            )

        if self.server_args.elastic_ep_backend == "mooncake":
            # Mooncake does not support `monitored_barrier`
            dist.barrier(group=get_tp_group().cpu_group)
        else:
            # Handle the case where some ranks do not finish loading.
            try:
                dist.monitored_barrier(
                    group=get_tp_group().cpu_group,
                    timeout=datetime.timedelta(
                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
                    ),
                    wait_all_ranks=True,
                )
            except RuntimeError:
                raise ValueError(
                    f"TP rank {self.tp_rank} finished model loading, but other ranks did not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
                ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        if ElasticEPStateManager.instance() is not None:
            # TODO: refactor the weights update when elastic ep
            old_expert_location_metadata = get_global_expert_location_metadata()
            assert old_expert_location_metadata is not None
            old_expert_location_metadata.update(
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
            )
            self.update_weights_from_disk(
                self.server_args.model_path,
                self.server_args.load_format,
                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
            )
        else:
            self.expert_location_updater.update(
                self.model.routed_experts_weights_of_layer,
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
                nnodes=self.server_args.nnodes,
                rank=self.tp_rank,
            )

    def update_weights_from_disk(
        self,
        model_path: str,
        load_format: str,
        weight_name_filter: Optional[Callable[[str], bool]] = None,
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config, self.model_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            if weight_name_filter is not None:
                iter = (
                    (name, weight) for name, weight in iter if weight_name_filter(name)
                )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

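    # Remote-instance weight transfer: each TP rank opens a dedicated process group
    # with the destination instance, then broadcasts every named parameter from
    # rank 0 over that group (see the two methods below).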
    def init_weights_send_group_for_remote_instance(
        self,
        master_address,
        ports,
        group_rank,
        world_size,
        group_name,
        backend="nccl",
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        logger.info(
            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            self._weights_send_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{group_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=torch.device("cuda", self.gpu_id),
            )
            dist.barrier(group=self._weights_send_group[group_name])
            success = True
            message = (
                f"Succeeded to init group through {master_address}:{group_port} group."
            )
        except Exception as e:
            message = f"Failed to init group: {e}."
            logger.error(message)

        torch.cuda.empty_cache()
        return success, message

    def send_weights_to_remote_instance(
        self,
        master_address,
        ports,
        group_name,
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        if self._weights_send_group.get(group_name) is not None:
            send_group = self._weights_send_group[group_name]
        else:
            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
            logger.error(message)
            return False, message

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            for _, weights in self.model.named_parameters():
                torch.distributed.broadcast(
                    weights,
                    src=0,
                    group=send_group,
                )
            success = True
            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
        except Exception as e:
            message = f"Failed to send weights: {e}."
            logger.error(message)

        # destroy the process group after sending weights
        del self._weights_send_group[group_name]
        torch.distributed.distributed_c10d.destroy_process_group(send_group)
        torch.cuda.empty_cache()
        return success, message

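    # Hypothetical usage sketch for the RLHF weight-update path (the address,
    # port, and group name below are made-up illustrative values):
    #
    #   ok, msg = runner.init_weights_update_group(
    #       "10.0.0.1", 29500, rank_offset=1, world_size=tp_size + 1, group_name="rlhf"
    #   )
    #   ok, msg = runner.update_weights_from_distributed(names, dtypes, shapes, "rlhf")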
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def destroy_weights_update_group(self, group_name):
        try:
            if group_name in self._model_update_group:
                pg = self._model_update_group.pop(group_name)
                torch.distributed.destroy_process_group(pg)
                return True, "Succeeded to destroy custom process group."
            else:
                return False, "The group to be destroyed does not exist."
        except Exception as e:
            message = f"Failed to destroy custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameters in the model weights online
        through the `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        try:
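            # Broadcast every parameter asynchronously from rank 0 (the training
            # engine), wait on all handles, then load the received tensors into the model.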
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        monkey_patch_torch_reductions()
        if load_format == "flattened_bucket":
            # Handle flattened bucket format
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get device after patch otherwise the device would be wrong
        self.device_module = torch.get_device_module(self.device)
        infered_device = self.device_module.current_device()

        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
            for name, tensor in named_tensors
        ]
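        # Dispatch on load_format: write tensors directly, use a custom loader
        # registered via --custom-weight-loader, or fall back to the model's own load_weights().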
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle flattened bucket format for weight updates"""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert metadata dict to our format
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create bucket and reconstruct tensors
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
            lora_paths=self.server_args.lora_paths,
            server_args=self.server_args,
        )

    def load_lora_adapter(self, lora_ref: LoRARef):
1215
1216
1217
        """Load a new lora adapter from disk or huggingface."""

        logger.info(
1218
            f"LoRA adapter loading starts: {lora_ref}. "
1219
1220
1221
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

1222
        result = self.lora_manager.load_lora_adapter(lora_ref)
1223
1224

        logger.info(
1225
            f"LoRA adapter loading completes: {lora_ref}. "
1226
1227
1228
1229
1230
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        elif config := self.mambaish_config:
            num_layers = len(config.full_attention_layer_ids)
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
            if is_deepseek_nsa(self.model_config.hf_config):
                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
                indexer_size_per_token = (
                    index_head_dim
                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
                )
                element_size = torch._utils._element_size(
                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
                )
                cell_size += indexer_size_per_token * num_layers * element_size
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
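        # Worked example (illustrative numbers, not measured): an MHA model with 8 KV heads,
        # head_dim=128, 32 layers and an fp16 KV cache needs
        # 8 * 128 * 32 * 2 (K and V) * 2 bytes = 128 KiB per token, so each free GiB of
        # memory holds roughly 8192 cacheable tokens.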
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        if self.mambaish_config is not None:
            rest_memory = self.handle_max_mamba_cache(rest_memory)
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def handle_max_mamba_cache(self, total_rest_memory):
        config = self.mambaish_config
        server_args = self.server_args
        assert config is not None

        speculative_ratio = (
            0
            if server_args.speculative_num_draft_tokens is None
            else server_args.speculative_num_draft_tokens
        )
        if (
            server_args.disable_radix_cache
            or config.mamba2_cache_params.mamba_cache_per_req == 0
        ):
            # when radix cache is disabled, set max_mamba_cache_size based on max_running_requests
            if server_args.max_mamba_cache_size is None:
                if server_args.max_running_requests is not None:
                    server_args.max_mamba_cache_size = server_args.max_running_requests
                else:
                    server_args.max_mamba_cache_size = 512
        else:
            # allocate the memory based on the ratio of mamba state memory to full kv cache memory
            # solve the equations:
            # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
            # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
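            # => mamba_state_memory = total_rest_memory * ratio / (1 + ratio).
            # Illustrative numbers: 12 GiB of rest memory with mamba_full_memory_ratio=0.2
            # gives 12 * 0.2 / 1.2 = 2 GiB for the mamba state and 10 GiB for the full KV cache.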
            mamba_state_memory_raw = (
                total_rest_memory
                * server_args.mamba_full_memory_ratio
                / (1 + server_args.mamba_full_memory_ratio)
            )
            # calculate the max_mamba_cache_size based on the given total mamba memory
            server_args.max_mamba_cache_size = int(
                (mamba_state_memory_raw * (1 << 30))
                // config.mamba2_cache_params.mamba_cache_per_req
                // (1 + speculative_ratio)
            )

        if self.hybrid_gdn_config is not None:
            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
                server_args.dp_size if server_args.enable_dp_attention else 1
            )
        mamba_state_memory = (
            server_args.max_mamba_cache_size
            * config.mamba2_cache_params.mamba_cache_per_req
            * (1 + speculative_ratio)
            / (1 << 30)
        )
        return total_rest_memory - mamba_state_memory

    @property
    def hybrid_gdn_config(self):
        config = self.model_config.hf_config
        if isinstance(config, Qwen3NextConfig):
            return config
        return None

    @property
    def mamba2_config(self):
        config = self.model_config.hf_config
        if isinstance(config, FalconH1Config | NemotronHConfig):
            return config
        return None

    @property
    def mambaish_config(self):
        return self.mamba2_config or self.hybrid_gdn_config

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            try:
                layers = self.model.model.layers
            except:
                try:
                    layers = self.model.language_model.model.layers
                except:
                    try:
                        layers = self.model.language_model.layers
                    except:
                        self.is_hybrid = False
                        return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # Existing max_total_num_tokens is per layer and assumes all layers have the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
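            # => full_max_total_num_tokens = total_tokens / (swa_full_tokens_ratio * swa_layers_num + full_layers_num)
            #    and swa_max_total_num_tokens = swa_full_tokens_ratio * full_max_total_num_tokens.
            # Illustrative numbers: total_tokens=48_000 with 8 SWA layers, 4 full layers and
            # ratio=2 gives full=2_400 and swa=4_800 tokens per layer.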
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

    def can_run_piecewise_cuda_graph(self):
        if self.server_args.disable_cuda_graph:
            log_info_on_rank0(
                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
            )
            return False
        if self.server_args.enable_torch_compile:
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
            )
            return False
        if self.pp_size > 1:
            # TODO(yuwei): support PP
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
            )
            return False
        return True

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        # Determine the kv cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            quant_config = getattr(self.model, "quant_config", None)
            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
            if (
                isinstance(kv_cache_quant_algo, str)
                and kv_cache_quant_algo.upper() == "FP8"
            ):
                if _is_hip:
                    self.kv_cache_dtype = torch.float8_e4m3fnuz
                else:
                    self.kv_cache_dtype = torch.float8_e4m3fn
            else:
                self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
            self.kv_cache_dtype = torch.bfloat16
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )
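        # Illustrative default (hypothetical values): with 500_000 profiled tokens and a
        # context length of 32_768, 500_000 / 32_768 * 512 ≈ 7_812, which is then clamped
        # into the [2048, 4096] band, so max_num_reqs defaults to 4096.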

        if self.mambaish_config is not None:
            ratio = (
                MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
                if not self.server_args.disable_radix_cache
                else 1
            )
            max_num_reqs = min(
                max_num_reqs, self.server_args.max_mamba_cache_size // ratio
            )

        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and the draft worker share the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs
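                # Illustrative sizing (hypothetical values): with max_num_reqs=64,
                # speculative_num_steps=3, speculative_eagle_topk=4 and
                # speculative_num_draft_tokens=8, the shared pool grows by
                # 64*3*4 + 64*8 + 100 = 1380 extra token slots.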

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        # different pp ranks may have different numbers of layers, so we need to reduce the max_total_num_tokens
        if self.pp_size > 1:
            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=get_world_group().cpu_group,
            )
            self.max_total_num_tokens = tensor.item()

        # create token size for hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                f"Not enough memory. Please try to increase --mem-fraction-static. "
                f"Current value: {self.server_args.mem_fraction_static=}"
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
            extra_max_context_len = 4
            if self.server_args.speculative_num_draft_tokens is not None:
                extra_max_context_len += self.server_args.speculative_num_draft_tokens

            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import (
                    DecodeReqToTokenPool,
                    HybridMambaDecodeReqToTokenPool,
                )

                # subscribe memory for pre-allocated requests
                # if max_num_reqs <= 32, we pre-allocate 2x requests
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                if config := self.mambaish_config:
                    self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
                        size=max_num_reqs,
                        max_context_len=self.model_config.context_len
                        + extra_max_context_len,
                        device=self.device,
                        enable_memory_saver=self.server_args.enable_memory_saver,
                        cache_params=config.mamba2_cache_params,
                        speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                        pre_alloc_size=pre_alloc_size,
                    )
                else:
                    self.req_to_token_pool = DecodeReqToTokenPool(
                        size=max_num_reqs,
                        max_context_len=self.model_config.context_len
                        + extra_max_context_len,
                        device=self.device,
                        enable_memory_saver=self.server_args.enable_memory_saver,
                        pre_alloc_size=pre_alloc_size,
                    )
            elif config := self.mambaish_config:
                self.req_to_token_pool = HybridReqToTokenPool(
                    size=max_num_reqs,
                    mamba_size=self.server_args.max_mamba_cache_size,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    cache_params=config.mamba2_cache_params,
                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    index_head_dim=self.model_config.index_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
        elif self.use_mla_backend and is_nsa_model:
            self.token_to_kv_pool = NSATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
            )
        elif self.use_mla_backend:
            assert not is_nsa_model
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            elif config := self.mambaish_config:
                self.token_to_kv_pool = HybridLinearKVPool(
                    page_size=self.page_size,
                    size=self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    # if draft worker, we only need 1 attention layer's kv pool
                    full_attention_layer_ids=(
                        [0] if self.is_draft_worker else config.full_attention_layer_ids
                    ),
                    enable_kvcache_transpose=False,
                    device=self.device,
                    mamba_pool=self.req_to_token_pool.mamba_pool,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                    enable_alt_stream=not self.server_args.enable_pdmux,
                    enable_kv_cache_copy=(
                        self.server_args.speculative_algorithm is not None
                    ),
                )

        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
            if _is_npu and (
                self.server_args.attention_backend == "ascend"
                or self.hybrid_gdn_config is not None
            ):
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                )
            else:
                if self.page_size == 1:
                    if self.is_hybrid:
                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                            self.full_max_total_num_tokens,
                            self.swa_max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                    else:
                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                            self.max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                else:
                    assert not self.is_hybrid
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_pdmux:
            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
            self.decode_attn_backend_group = []
            for _ in range(self.server_args.sm_group_num):
                self.decode_attn_backend_group.append(self._get_attention_backend())
            self.decode_attn_backend = self.decode_attn_backend_group[0]
        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self, init_new_workspace: bool = False):
        """Init attention kernel backend."""
        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
            self.server_args.get_attention_backends()
        )

        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                self,
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str,
                    init_new_workspace=init_new_workspace,
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str,
                    init_new_workspace=init_new_workspace,
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend,
                init_new_workspace=init_new_workspace,
            )

        (
            get_global_server_args().prefill_attention_backend,
            get_global_server_args().decode_attention_backend,
        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
        return attn_backend

    def _get_attention_backend_from_str(
        self, backend_str: str, init_new_workspace: bool = False
    ):
        if backend_str not in ATTENTION_BACKENDS:
            raise ValueError(f"Invalid attention backend: {backend_str}")
        self.init_new_workspace = init_new_workspace
        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
        return attn_backend_wrapper(self, full_attention_backend)

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_device_graphs(self):
        """Capture device graphs."""
        self.graph_runner = None
        self.graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.device != "cpu" and self.server_args.disable_cuda_graph:
            return

        if self.device == "cpu" and not self.server_args.enable_torch_compile:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        graph_runners = defaultdict(
            lambda: CudaGraphRunner,
            {
                "cpu": CPUGraphRunner,
                "npu": NPUGraphRunner,
            },
        )
        self.graph_runner = graph_runners[self.device](self)

        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
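        # Illustrative setting (not the only valid form): on a machine with cores 0-31 on
        # numa node 0 and cores 32-63 on node 1, SGLANG_CPU_OMP_THREADS_BIND="0-31|32-63"
        # with -tp 2 binds tp rank 0 to cores 0-31 and tp rank 1 to cores 32-63; the
        # default "all" gives each tp rank one whole numa node.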
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)
        if omp_cpuids == "all":
            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            threads_bind_list = omp_cpuids.split("|")
            assert self.tp_size == len(threads_bind_list), (
                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
                f"Please double check your settings."
            )
            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
            if self.tp_size > n_numa_node:
                logger.warning(
                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
                    f"in this case the available memory amount of each rank cannot be determined in prior. "
                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
                )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.layers.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def update_decode_attn_backend(self, stream_idx: int):
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            if self.server_args.enable_pdmux:
                self.decode_attn_backend.init_forward_metadata(forward_batch)
                forward_batch.attn_backend = self.decode_attn_backend
            else:
                self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True

        if self.piecewise_cuda_graph_runner is not None:
            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)

        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
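        # Illustrative schedule (hypothetical numbers): with 32 hidden layers and
        # forward_count=8, each call runs layers [split_index, split_index + 8) and then
        # advances split_index, so four calls walk the whole model.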
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        mode_check = (
            forward_batch.forward_mode.is_cpu_graph
            if self.device == "cpu"
            else forward_batch.forward_mode.is_cuda_graph
        )
        can_run_graph = bool(
            mode_check()
            and self.graph_runner
            and self.graph_runner.can_run(forward_batch)
        )

        if can_run_graph:
            ret = self.graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if (
            forward_batch.global_num_tokens_cpu is not None
            and self.pp_group.is_last_rank
        ):
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
        #       was executed after we processed last batch's results.

        # Calculate logits bias and apply it to next_token_logits.
        sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)
        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
            # For prefill, we only use the position of the last token.
            (
                forward_batch.positions
                if forward_batch.forward_mode.is_decode()
                else forward_batch.seq_lens - 1
            ),
        )
        return next_token_ids

    def compute_logprobs_only(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> None:
        """
        Compute token_ids_logprobs without performing sampling.

        Optimized path for prefill-only requests that need token_ids_logprobs but don't
        require next token generation. Skips expensive sampling operations
        while still providing requested probability information.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output
        """
        if not forward_batch.token_ids_logprobs:
            return

        # Preprocess logits (same as in sample method)
        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Delegate to sampler for logprob-only computation
        # This populates logits_output with requested token probabilities
        self.sampler.compute_logprobs_only(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

    def update_weights_from_ipc(self, recv_req):
        """Update weights from IPC for checkpoint-engine integration."""
        try:
            from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
                SGLangCheckpointEngineWorkerExtensionImpl,
            )

            # Create a worker extension that integrates with SGLang's model
            worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
            worker.update_weights_from_ipc(recv_req.zmq_handles)
            return True, "IPC weight update completed successfully"
        except ImportError as e:
            return False, f"IPC weight update failed: ImportError {e}"
        except Exception as e:
            logger.error(f"IPC weight update failed: {e}")
            return False, str(e)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""
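    # Illustrative flow (assumed, for clarity): each TP rank serializes its local shard
    # via MultiprocessingSerializer, the sender gathers the per-rank payloads into
    # `values`, and the receiver on rank r calls .get(r) to rebuild only its own shard.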

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])