# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
    GLOBAL_SERVER_ARGS_KEYS,
    global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
    AscendPagedTokenToKVPoolAllocator,
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    is_npu,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Apply the rank zero filter to logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.forward_pass_id = 0

        # Model-specific adjustment
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars
        global_server_args_dict.update(
            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
            | {
                # TODO: these are not really "server args"
                "use_mla_backend": self.use_mla_backend,
                "speculative_algorithm": self.spec_algorithm,
            }
            | {
                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
                "deepep_mode": DeepEPMode(server_args.deepep_mode),
            }
        )

        # CPU offload
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Update the DeepGEMM configuration
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # Temporarily cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )
        self._model_update_group = {}

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else self.model_config.num_hidden_layers
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (not model_has_mtp_layers) or (
            self.num_effective_layers == model_num_layers
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have already been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init LoRA
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.cuda_graph_mem_usage = 0
            self.init_attention_backend()

        # Auxiliary hidden-state capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # Load the draft model config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # Get the aux layer ids from the draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except Exception:
                # If there is no aux layer config, fall back to None
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fall back to the torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.prefill_attention_backend is not None and (
            server_args.prefill_attention_backend
            == server_args.decode_attention_backend
        ):  # override the default attention backend
            server_args.attention_backend = server_args.prefill_attention_backend

        if (
            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
            is not None
        ):
            if server_args.attention_backend is None:
                server_args.attention_backend = "dual_chunk_flash_attn"
                logger.info("Dual chunk attention is turned on by default.")
            elif server_args.attention_backend != "dual_chunk_flash_attn":
                raise ValueError(
                    "Dual chunk attention is enabled, but attention backend is set to "
                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
                )

        if server_args.attention_backend is None:
            """
            Auto select the fastest attention backend.

            1. Models with MHA architecture (e.g., Llama, Qwen)
                1.1 We turn on FA3 on Hopper unless the user uses speculative decoding with topk > 1 or page_size > 1.
                1.2 In other cases, we use flashinfer if available, otherwise triton.
            2. Models with MLA architecture
                2.1 We use the FA3 backend on Hopper.
                2.2 We use the FlashInfer backend on Blackwell.
                2.3 Otherwise, we use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO: aiter currently only supports head numbers 16 and 128
                    if (
                        head_num == 128 or head_num == 16
                    ) and self.spec_algorithm.is_none():
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "aiter",
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
                    "cutlass_mla",
                    "trtllm_mla",
                    "ascend",
                ]:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization is not supported on CPU except for the intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 when using FP8; "
                "setting the attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turning off --chunked-prefill-size because it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

        if (
            server_args.enable_hierarchical_cache
            and server_args.hicache_io_backend == "kernel"
        ):
            # Fix for the compatibility issue between FlashAttention3 decoding and the HiCache kernel backend
            if server_args.decode_attention_backend is None:
                if not self.use_mla_backend:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
                else:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_sm100_supported() else "triton"
                    )
            elif server_args.decode_attention_backend == "fa3":
                server_args.hicache_io_backend = "direct"
                logger.warning(
                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
                else:
                    logger.warning(
                        "init_cpu_threads_env and shared-memory-based AllReduce are disabled because the Intel AMX backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove these monkey patches when the quantization code in linear.py no longer depends on vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished model loading, but other ranks did not. "
                "This is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only DefaultModelLoader is supported for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameters in the model weights online
        through the `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
            group_name: the name of the process group created by `init_weights_update_group`.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        try:
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The weights of the ModelRunner are now partially updated; "
                f"please discard the whole set of weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
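        # Illustrative only (hypothetical call): plain tensors can be loaded with the default
        # path, e.g.
        #   runner.update_weights_from_tensor([("lm_head.weight", new_lm_head)], load_format=None)
        # while "direct", names registered in `server_args.custom_weight_loader`, and
        # "flattened_bucket" select the alternative loading paths handled below.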
        monkey_patch_torch_reductions()
        if load_format == "flattened_bucket":
            # Handle flattened bucket format
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get the device after the patch; otherwise the device would be wrong
        inferred_device = torch.cuda.current_device()

        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=inferred_device))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle flattened bucket format for weight updates"""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert metadata dict to our format
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create bucket and reconstruct tensors
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; the implementation is not optimized for performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
            lora_paths=self.server_args.lora_paths,
        )

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new LoRA adapter from disk or Hugging Face."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            # FIXME: pipeline parallelism is not compatible with the MLA backend
            assert self.pp_size == 1
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
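        # Worked example (hypothetical MHA config): 8 KV heads * head_dim 128 * 32 layers
        # * 2 (K and V) * 2 bytes (fp16/bf16) = 131072 bytes, i.e. 128 KiB of KV cache per token.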
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
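            # The arithmetic below splits the per-4-layer token budget (4 * max_total_num_tokens)
            # so that swa = full * temp_ratio and 3 * swa + full = 4 * max_total_num_tokens,
            # i.e. swa = 4*T*r/(3r+1) and full = 4*T/(3r+1). The 3:1 weighting presumably
            # reflects Llama4 interleaving three chunked-attention layers per full-attention
            # layer (an assumption of this note, not something enforced by the code).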
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            try:
                layers = self.model.model.layers
            except AttributeError:
                try:
                    layers = self.model.language_model.model.layers
                except AttributeError:
                    try:
                        layers = self.model.language_model.layers
                    except AttributeError:
                        self.is_hybrid = False
                        return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # Existing max_total_num_tokens is per layer and assumes all layers hold the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
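            # Worked example (hypothetical numbers): with max_total_num_tokens=100_000,
            # 40 layers (30 SWA + 10 full) and swa_full_tokens_ratio=0.5:
            #   total_tokens = 4_000_000, denominator = 0.5 * 30 + 10 = 25,
            #   full_max_total_num_tokens = 160_000, swa_max_total_num_tokens = 80_000,
            # and indeed 80_000 * 30 + 160_000 * 10 == 4_000_000.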
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

    def init_memory_pool(
1158
1159
        self,
        total_gpu_memory: int,
1160
1161
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
1162
    ):
1163
1164
1165
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
1166
            if _is_hip:  # Using natively supported format
HAI's avatar
HAI committed
1167
1168
1169
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
bjmsong's avatar
bjmsong committed
1170
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
1171
1172
1173
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
bjmsong's avatar
bjmsong committed
1174
                self.kv_cache_dtype = torch.float8_e4m3fn
1175
1176
1177
1178
1179
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

1180
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

1193
1194
1195
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

1196
1197
1198
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
1199
                max_num_reqs = self.server_args.max_num_reqs
1200
            else:
1201
1202
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be allocated concurrently, so we should leave some headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
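                # Worked example (hypothetical numbers): with max_total_num_tokens=100_000,
                # max_num_reqs=64, speculative_num_steps=4, speculative_eagle_topk=2, and
                # speculative_num_draft_tokens=8, the draft runner cache size becomes
                # 100_000 + 64 * 4 * 2 + 64 * 8 + 100 = 101_124 tokens.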
                # The target worker and the draft worker share the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
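        # Round down so the token budget is a multiple of the page size.
        # Worked example (hypothetical numbers): with page_size=16, a profiled budget of
        # 100_001 tokens becomes 100_001 // 16 * 16 = 100_000.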
        # Create the token sizes for the hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # Reserve memory for pre-allocated requests
                # if max_num_reqs <= 32, we pre-allocate 2x requests
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.model_config.num_hidden_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        elif self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )

        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                if self.is_hybrid:
                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                        self.full_max_total_num_tokens,
                        self.swa_max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
                else:
                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
            else:
                if not _is_npu:
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
                else:
                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self):
        """Init attention kernel backend."""
        self.decode_attention_backend_str = (
            self.server_args.decode_attention_backend
            if self.server_args.decode_attention_backend
            else self.server_args.attention_backend
        )
        self.prefill_attention_backend_str = (
            self.server_args.prefill_attention_backend
            if self.server_args.prefill_attention_backend
            else self.server_args.attention_backend
        )
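        # For example, decode_attention_backend="flashinfer" together with
        # prefill_attention_backend="triton" (a hypothetical combination) takes the
        # HybridAttnBackend branch below; when neither is set, both fall back to the
        # single --attention-backend choice and a plain backend is created.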
        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            assert (
                self.server_args.speculative_algorithm is None
            ), "Currently HybridAttentionBackend does not support speculative decoding."
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "The attention backend specified by --attention-backend (or the default) may be overridden by the hybrid decode/prefill setting. "
                "The hybrid attention backend feature is experimental and unstable. Please raise an issue if you encounter any problem."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend
            )

        global_server_args_dict.update(
            {
                "decode_attention_backend": self.decode_attention_backend_str,
                "prefill_attention_backend": self.prefill_attention_backend_str,
            }
        )
        return attn_backend

    def _get_attention_backend_from_str(self, backend_str: str):
        if backend_str == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    if (
                        not hasattr(self, "plan_stream_for_flashinfer")
                        or not self.plan_stream_for_flashinfer
                    ):
                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                return FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                return FlashInferMLAAttnBackend(self)
        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
        elif backend_str == "wave":
            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend

            return WaveAttnBackend(self)
        elif backend_str == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
        elif backend_str == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                return DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                return TritonAttnBackend(self)
        elif backend_str == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            return TorchNativeAttnBackend(self)
        elif backend_str == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            return FlashMLABackend(self)
        elif backend_str == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            return FlashAttentionBackend(self)
        elif backend_str == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            return CutlassMLABackend(self)
        elif backend_str == "trtllm_mla":
            if not self.use_mla_backend:
                raise ValueError("trtllm_mla backend can only be used with MLA models.")
            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

            return TRTLLMMLABackend(self)
        elif backend_str == "trtllm_mha":
            if self.use_mla_backend:
                raise ValueError(
                    "trtllm_mha backend can only be used with non-MLA models."
                )
            from sglang.srt.layers.attention.trtllm_mha_backend import (
                TRTLLMHAAttnBackend,
            )

            return TRTLLMHAAttnBackend(self)

        elif backend_str == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        elif backend_str == "dual_chunk_flash_attn":
            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
                DualChunkFlashAttentionBackend,
            )

            return DualChunkFlashAttentionBackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend_str}")

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.cuda_graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        if omp_cpuids == "all":
            cpu_ids_by_node = get_cpu_ids_by_node()
            n_numa_node = len(cpu_ids_by_node)

            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set. In this case, "
                f"tp_size {self.tp_size} should be smaller than or equal to the number of NUMA nodes on the machine, which is {n_numa_node}. "
                f"If you need tp_size to be larger than the number of NUMA nodes, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                f"For example, on a machine with 2 NUMA nodes, where cores 0-31 are on NUMA node 0 and cores 32-63 are on NUMA node 1, "
                f"it is suggested to use -tp 2 and bind tp rank 0 to cores 0-31 and tp rank 1 to cores 32-63. "
                f"This is the default behavior when SGLANG_CPU_OMP_THREADS_BIND is not set; it is equivalent to setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you do need tp_size to be larger than the number of NUMA nodes, you could set, for example, SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
                f"If you don't want each tp rank to use all the cores on one NUMA node, you could set, for example, SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"The current machine has {n_numa_node} NUMA nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} NUMA nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
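        # For example (hypothetical numbers), with num_hidden_layers=32, split_index=28,
        # and forward_count=8, next_split_index is clamped to 32, so this call only runs
        # layers 28-31 and finishes the split prefill.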
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            ret = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_cuda_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Update the grammar vocab mask (if needed) and apply the logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A tensor of next token ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
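        # For example, a hypothetical rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
        # makes this property return True, while rope_scaling = None returns False.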
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])