# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
    GLOBAL_SERVER_ARGS_KEYS,
    global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    is_npu,
    is_sm100_supported,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
        self.forward_pass_id = 0

        # Apply the rank zero filter to the logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        # Global vars
        global_server_args_dict.update(
            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
            | {
                # TODO it is indeed not a "server args"
                "use_mla_backend": self.use_mla_backend,
                "speculative_algorithm": self.spec_algorithm,
            }
        )

        # Init OpenMP thread binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # CPU offload
        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

        # Update the DeepGEMM configuration
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # Temporary cached values
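        # Pipeline parallelism can only be used if the model's forward() accepts a
        # `pp_proxy_tensors` argument.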
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
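        # If the model reports a positive sliding window and hybrid SWA memory is not
        # disabled, enable the hybrid KV pool that sizes SWA and full-attention layers
        # separately. Llama4 architectures are excluded from this check.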
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (
            (not model_has_mtp_layers)
            or (self.spec_algorithm.is_none())
            or (
                (not self.spec_algorithm.is_none())
                and (self.num_effective_layers == model_num_layers)
            )
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init Double Sparsity
        if server_args.enable_double_sparsity:
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_device_graphs()
        elif self.device == "npu":
            self.init_attention_backend()
            self.init_device_graphs()
        else:
            self.graph_runner = None
            self.cuda_graph_mem_usage = 0
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
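            # EAGLE3 captures auxiliary hidden states from selected target-model layers;
            # the layer ids are read from the draft model's HF config when available,
            # otherwise None is passed.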
            # load draft config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # get the aux layer from draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except Exception:
                # if there is no aux layer, set to None
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX; falling back to the torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.prefill_attention_backend is not None and (
            server_args.prefill_attention_backend
            == server_args.decode_attention_backend
        ):  # override the default attention backend
            server_args.attention_backend = server_args.prefill_attention_backend

        if (
            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
            is not None
        ):
            if server_args.attention_backend is None:
                server_args.attention_backend = "dual_chunk_flash_attn"
                logger.info("Dual chunk attention is turned on by default.")
            elif server_args.attention_backend != "dual_chunk_flash_attn":
                raise ValueError(
                    "Dual chunk attention is enabled, but attention backend is set to "
                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
                )

        if server_args.attention_backend is None:
            """
            Auto select the fastest attention backend.

            1. Models with MHA architecture (e.g., Llama, Qwen)
                1.1 We will turn on FA3 on Hopper unless the user uses speculative decoding with topk > 1 or page_size > 1.
                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
            2. Models with MLA architecture
                2.1 We will use the FA3 backend on Hopper.
                2.2 We will use the FlashInfer backend on Blackwell.
                2.3 Otherwise, we will use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO: currently aiter only supports head numbers of 16 or 128
                    if (
                        head_num == 128 or head_num == 16
                    ) and self.spec_algorithm.is_none():
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "aiter",
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
                    "cutlass_mla",
                    "trtllm_mla",
                    "ascend",
                ]:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization not supported on CPU except for intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
        #  For more details, see: https://github.com/sgl-project/sglang/issues/8616
        elif (
            self.dp_size > 1
            and is_sm100_supported()
            and server_args.attention_backend != "triton"
        ):
            logger.info(
                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
            )
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

        if (
            server_args.enable_hierarchical_cache
            and server_args.hicache_io_backend == "kernel"
        ):
            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
            if server_args.decode_attention_backend is None:
                if not self.use_mla_backend:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
                else:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_sm100_supported() else "triton"
                    )
            elif server_args.decode_attention_backend == "fa3":
                server_args.hicache_io_backend = "direct"
                logger.warning(
                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
                    "Setting hicache_io_backend to 'direct' I/O, which may lead to suboptimal performance with small page sizes."
                )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared-memory-based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
                else:
                    logger.warning(
                        "init_cpu_threads_env and shared-memory-based AllReduce are disabled because the intel_amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
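            # The global process group spans tp_size * pp_size ranks, laid out
            # pipeline-major: rank = pp_rank * tp_size + tp_rank.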
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove these monkey patches when linear.py's quantization no longer depends on vLLM
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        get_offloader().post_init()

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
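        # A monitored barrier with a timeout raises instead of hanging forever, so a rank
        # that failed or is still loading is surfaced as an error below.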
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameters in the model weights online
        through the `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
            group_name: the name of the process group used for broadcasting.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        try:
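            # Post all broadcasts asynchronously (rank 0, the training actor, is the
            # source) and wait on the handles afterwards so the transfers can overlap.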
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        monkey_patch_torch_reductions()
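        # Dispatch on load_format: "flattened_bucket" rebuilds tensors from a single
        # flattened buffer, "direct" uses the direct loader, a name listed in
        # server_args.custom_weight_loader is dynamically imported and called, and None
        # falls back to the model's own load_weights().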
        if load_format == "flattened_bucket":
            # Handle flattened bucket format
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get the device after the patch; otherwise the device would be wrong.
        self.device_module = torch.get_device_module(self.device)
        infered_device = self.device_module.current_device()

        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle flattened bucket format for weight updates"""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert metadata dict to our format
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create bucket and reconstruct tensors
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; the implementation is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
            lora_paths=self.server_args.lora_paths,
        )

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new LoRA adapter from disk or Hugging Face."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
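        # Estimate how many KV-cache tokens fit in the remaining memory budget. The
        # per-token cell size covers all effective layers: MLA stores one compressed
        # latent (kv_lora_rank + qk_rope_head_dim) per layer, while MHA stores K and V
        # for each local KV head.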
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
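            # Split the per-layer budget between chunked (SWA) and full-attention layers
            # by solving full + 3 * swa == 4 * max_total_num_tokens with
            # swa == temp_ratio * full, i.e. assuming three chunked-attention layers for
            # every full-attention layer.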
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            try:
                layers = self.model.model.layers
            except Exception:
                try:
                    layers = self.model.language_model.model.layers
                except Exception:
                    try:
                        layers = self.model.language_model.layers
                    except Exception:
                        self.is_hybrid = False
                        return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # Existing max_total_num_tokens is per layer and assumes all layers have the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

    def init_memory_pool(
1181
1182
        self,
        total_gpu_memory: int,
1183
1184
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
1185
    ):
Lianmin Zheng's avatar
Lianmin Zheng committed
1186
        # Determine the kv cache dtype
1187
1188
1189
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
1190
            if _is_hip:  # Using natively supported format
HAI's avatar
HAI committed
1191
1192
1193
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
bjmsong's avatar
bjmsong committed
1194
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
1195
1196
1197
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
bjmsong's avatar
bjmsong committed
1198
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )
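        # Illustrative example with assumed numbers: max_total_num_tokens=200_000 and
        # context_len=8192 give 200_000 / 8192 * 512 = 12_500, which is clamped to the
        # [2048, 4096] range, so max_num_reqs = 4096.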

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
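                # Illustrative example with assumed numbers: max_total_num_tokens=100_000,
                # max_num_reqs=2048, speculative_num_steps=3, speculative_eagle_topk=4,
                # and speculative_num_draft_tokens=8 give a cache size of
                # 100_000 + 2048 * 3 * 4 + 2048 * 8 + 100 = 141_060 tokens.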
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
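        # Example: with page_size=16, a profiled budget of 131_075 tokens is rounded
        # down to 131_072 tokens (8_192 full pages).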
        # Compute the per-pool token budgets for the hybrid (SWA + full attention) cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
            extra_max_context_len = 4
            if self.server_args.speculative_num_draft_tokens is not None:
                extra_max_context_len += self.server_args.speculative_num_draft_tokens

            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # Reserve memory for pre-allocated requests;
                # if max_num_reqs <= 32, we pre-allocate 2x requests.
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.model_config.num_hidden_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        elif self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )

        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
            if self.server_args.attention_backend == "ascend":
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                    need_sort=need_sort,
                )
            else:
                if self.page_size == 1:
                    if self.is_hybrid:
                        self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                            self.full_max_total_num_tokens,
                            self.swa_max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                    else:
                        self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                            self.max_total_num_tokens,
                            dtype=self.kv_cache_dtype,
                            device=self.device,
                            kvcache=self.token_to_kv_pool,
                            need_sort=need_sort,
                        )
                else:
                    assert not self.is_hybrid
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self):
        """Init attention kernel backend."""
        self.decode_attention_backend_str = (
            self.server_args.decode_attention_backend
            if self.server_args.decode_attention_backend
            else self.server_args.attention_backend
        )
        self.prefill_attention_backend_str = (
            self.server_args.prefill_attention_backend
            if self.server_args.prefill_attention_backend
            else self.server_args.attention_backend
        )
        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                self,
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "The attention backend specified by --attention-backend (or the default backend) may be overridden. "
                "The hybrid attention backend feature is experimental and unstable. Please raise an issue if you encounter any problems."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend
            )

        global_server_args_dict.update(
            {
                "decode_attention_backend": self.decode_attention_backend_str,
                "prefill_attention_backend": self.prefill_attention_backend_str,
            }
        )
        return attn_backend
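
    # Illustrative usage (assuming the CLI flags mirror the server args above): passing
    # `--prefill-attention-backend fa3 --decode-attention-backend flashinfer` makes this
    # method build a HybridAttnBackend with FlashAttention v3 for prefill and FlashInfer
    # for decode.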

    def _get_attention_backend_from_str(self, backend_str: str):
        if backend_str == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    if (
                        not hasattr(self, "plan_stream_for_flashinfer")
                        or not self.plan_stream_for_flashinfer
                    ):
                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                return FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                return FlashInferMLAAttnBackend(self)
        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
        elif self.server_args.attention_backend == "wave":
            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend

            return WaveAttnBackend(self)
        elif backend_str == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
        elif backend_str == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                return DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                return TritonAttnBackend(self)
        elif backend_str == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            return TorchNativeAttnBackend(self)
        elif backend_str == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            return FlashMLABackend(self)
        elif backend_str == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 backend requires SM 80 to SM 90, and MLA models require SM 90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            return FlashAttentionBackend(self)
        elif backend_str == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            return CutlassMLABackend(self)
        elif backend_str == "trtllm_mla":
            if not self.use_mla_backend:
                raise ValueError("trtllm_mla backend can only be used with MLA models.")
            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

            return TRTLLMMLABackend(self)
        elif backend_str == "trtllm_mha":
            if self.use_mla_backend:
                raise ValueError(
                    "trtllm_mha backend can only be used with non-MLA models."
                )
            from sglang.srt.layers.attention.trtllm_mha_backend import (
                TRTLLMHAAttnBackend,
            )

            return TRTLLMHAAttnBackend(self)
        elif backend_str == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            return IntelAMXAttnBackend(self)
        elif backend_str == "dual_chunk_flash_attn":
            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                DualChunkFlashAttentionBackend,
            )

            return DualChunkFlashAttentionBackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend_str}")

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)
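        # Assumed layout of the channel config JSON: a dict keyed by projection name,
        # e.g. "model.layers.0.self_attn.q_proj", whose value is a 2D list of channel
        # indices sorted by importance; only the first ds_heavy_channel_num are kept.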

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_device_graphs(self):
        """Capture cuda graphs."""
        self.graph_runner = None
        self.cuda_graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, CUDA graphs only capture decode steps, which exist only for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.graph_runner = (
            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
        )
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.cuda_graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        if omp_cpuids == "all":
            cpu_ids_by_node = get_cpu_ids_by_node()
            n_numa_node = len(cpu_ids_by_node)

            assert self.tp_size <= n_numa_node, (
                "SGLANG_CPU_OMP_THREADS_BIND is not set; in this case, "
                f"tp_size ({self.tp_size}) must be less than or equal to the number of NUMA nodes on the machine ({n_numa_node}). "
                "If you need tp_size to be larger than the number of NUMA nodes, set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                "For example, on a machine with 2 NUMA nodes where cores 0-31 are on node 0 and cores 32-63 are on node 1, "
                "it is suggested to use -tp 2 and bind tp rank 0 to cores 0-31 and tp rank 1 to cores 32-63; "
                "this is the default behavior when SGLANG_CPU_OMP_THREADS_BIND is not set and is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                "To run with tp_size larger than the number of NUMA nodes, set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
                "If you don't want each tp rank to use all the cores on one NUMA node, set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.layers.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.graph_runner
            and self.graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            ret = self.graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_cuda_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if (
            forward_batch.global_num_tokens_cpu is not None
            and self.pp_group.is_last_rank
        ):
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A tensor of next token ids.
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                dim=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
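        # e.g. Qwen2-VL-style configs set rope_scaling = {"type": "mrope",
        # "mrope_section": [...]}, so this property returns True for them.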
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)
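# Illustrative usage (hypothetical tensor name): overwrite a single parameter in place,
# bypassing the model loader:
#   _model_load_weights_direct(model, [("lm_head.weight", new_lm_head_tensor)])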


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to the i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
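
# Illustrative flow (a sketch, not prescribed usage): the sender serializes one handle
# per TP rank with MultiprocessingSerializer and ships LocalSerializedTensor(values=[...])
# across processes; the receiver on tp_rank r then calls `tensor.get(r)` (see
# _unwrap_tensor above) to materialize its shard on the local device.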