model_runner.py 68.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import datetime
17
import gc
18
import inspect
Shuo Yang's avatar
Shuo Yang committed
19
import json
20
import logging
21
import os
22
import time
23
24
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
25
26

import torch
27
import torch.distributed as dist
28
29
30
31

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
32
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
33
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
34
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
35
    get_tp_group,
36
    get_world_group,
zhyncs's avatar
zhyncs committed
37
38
    init_distributed_environment,
    initialize_model_parallel,
39
    set_custom_all_reduce,
40
    set_mscclpp_all_reduce,
zhyncs's avatar
zhyncs committed
41
)
42
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
fzyzcjy's avatar
fzyzcjy committed
43
44
45
46
47
48
49
50
51
52
53
54
55
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
56
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
57
58
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
59
    get_attention_tp_size,
60
61
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
62
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
63
64
65
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
66
)
67
from sglang.srt.layers.sampler import Sampler
68
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
69
from sglang.srt.layers.utils import is_sm100_supported
70
from sglang.srt.lora.lora_manager import LoRAManager
71
72
73
74
from sglang.srt.managers.schedule_batch import (
    GLOBAL_SERVER_ARGS_KEYS,
    global_server_args_dict,
)
75
from sglang.srt.mem_cache.allocator import (
76
    AscendPagedTokenToKVPoolAllocator,
77
78
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
tarinkk's avatar
tarinkk committed
79
    SWATokenToKVPoolAllocator,
80
81
    TokenToKVPoolAllocator,
)
82
from sglang.srt.mem_cache.memory_pool import (
83
84
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
Shuo Yang's avatar
Shuo Yang committed
85
    DoubleSparseTokenToKVPool,
86
87
88
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
tarinkk's avatar
tarinkk committed
89
    SWAKVPool,
90
)
Yineng Zhang's avatar
Yineng Zhang committed
91
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
92
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
93
from sglang.srt.model_loader import get_model
94
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
Lianmin Zheng's avatar
Lianmin Zheng committed
95
from sglang.srt.model_loader.utils import set_default_torch_dtype
96
from sglang.srt.model_loader.weight_utils import default_weight_loader
97
from sglang.srt.patch_torch import monkey_patch_torch_reductions
98
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
99
from sglang.srt.server_args import ServerArgs
100
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
101
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
102
from sglang.srt.utils import (
103
    MultiprocessingSerializer,
104
    cpu_has_amx_support,
105
    dynamic_import,
106
    enable_show_time_cost,
107
    get_available_gpu_memory,
108
    get_bool_env_var,
109
    get_cpu_ids_by_node,
110
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
111
    is_cuda,
112
    is_fa3_default_architecture,
113
    is_flashinfer_available,
HAI's avatar
HAI committed
114
    is_hip,
115
    is_hopper_with_cuda_12_3,
116
    is_no_spec_infer_or_topk_one,
117
    is_npu,
118
    monkey_patch_p2p_access_check,
119
    monkey_patch_vllm_gguf_config,
120
    set_cpu_offload_max_bytes,
121
    set_cuda_arch,
122
)
123

124
_is_hip = is_hip()
125
_is_npu = is_npu()
126
_is_cpu_amx_available = cpu_has_amx_support()
127

Lianmin Zheng's avatar
Lianmin Zheng committed
128
# Use a small KV cache pool size for tests in CI
129
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
Lianmin Zheng's avatar
Lianmin Zheng committed
130
131

# Detect straggler ranks in model loading
132
133
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

Lianmin Zheng's avatar
Lianmin Zheng committed
134
135
logger = logging.getLogger(__name__)

136

137
138
139
140
141
142
143
144
145
146
147
148
149
class RankZeroFilter(logging.Filter):
    """Logging filter that drops INFO records unless this process is rank 0.

    Records at every other level are let through regardless of rank.
    """

    def __init__(self, is_rank_zero):
        super().__init__()
        # Returned as-is from `filter` for INFO-level records.
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        """Return ``is_rank_zero`` for INFO records and ``True`` for all others."""
        return self.is_rank_zero if record.levelno == logging.INFO else True


Lianmin Zheng's avatar
Lianmin Zheng committed
150
class ModelRunner:
151
152
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
153
154
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        """Store the runner configuration, set up distributed state, and initialize.

        Args:
            model_config: Model configuration; several flags are copied from it.
            mem_fraction_static: Stored as ``self.mem_fraction_static``.
            gpu_id: Device index, also passed to the DeepGEMM config update.
            tp_rank: Tensor-parallel rank; rank 0 is the only one emitting INFO logs.
            tp_size: Tensor-parallel size.
            pp_rank: Pipeline-parallel rank.
            pp_size: Pipeline-parallel size.
            nccl_port: Stored as ``self.dist_port``.
            server_args: Server arguments; ``model_specific_adjustment`` may
                modify them in place.
            is_draft_worker: Flag stored as ``self.is_draft_worker``.
            req_to_token_pool: Optional pool stored on the runner as given.
            token_to_kv_pool_allocator: Optional allocator stored on the runner
                as given.
        """
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Install the rank-zero filter on the module logger once only.
        already_filtered = any(
            isinstance(existing, RankZeroFilter) for existing in logger.filters
        )
        if not already_filtered:
            logger.addFilter(RankZeroFilter(tp_rank == 0))

        # Parallelism layout
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size

        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker

        # Flags copied from the model config
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        self.forward_pass_id = 0

        # Model-specific adjustment (reads the attributes set above)
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars: selected server args plus two runner-derived entries.
        shared_args = {
            key: getattr(server_args, key) for key in GLOBAL_SERVER_ARGS_KEYS
        }
        # TODO it is indeed not a "server args"
        shared_args["use_mla_backend"] = self.use_mla_backend
        shared_args["speculative_algorithm"] = self.spec_algorithm
        global_server_args_dict.update(shared_args)

        # CPU offload
        offload_max_bytes = int(server_args.cpu_offload_gb * 1024**3)
        set_cpu_offload_max_bytes(offload_max_bytes)

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # temporary cached values
        forward_params = inspect.signature(self.model.forward).parameters
        self.support_pp = "pp_proxy_tensors" in forward_params
        self._model_update_group = {}
240

241
242
    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build the runtime state around it.

        In order: creates the memory-saver adapter, sets up the global expert
        location metadata and distribution recorder (non-draft workers only),
        loads the model, applies torchao / torch TP / LoRA where configured,
        initializes the memory pool and attention backend (plus cuBLAS and CUDA
        graphs on CUDA), and configures EAGLE3 auxiliary hidden-state capture.

        Args:
            min_per_gpu_memory: Value returned by ``init_torch_distributed``;
                forwarded to ``init_memory_pool``.
        """
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Turn on hybrid mode when the loaded model reports a positive sliding
        # window, unless disabled by the server args or the architecture list
        # contains a "Llama4" entry.
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        # Layer range held by this runner; defaults to all hidden layers when the
        # model does not expose start_layer / end_layer.
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
        # load LoRA adapters dynamically later.
        if server_args.lora_paths is not None:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.cuda_graph_mem_usage = 0
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # load draft config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # get the aux layer from draft model config
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except (AttributeError, KeyError, TypeError):
                # No eagle_config (None is not subscriptable) or no such key:
                # there is no aux layer, so set to None. Narrower than a bare
                # `except` so KeyboardInterrupt/SystemExit still propagate.
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
James Liu's avatar
James Liu committed
351

352
353
354
    def model_specific_adjustment(self):
        """Adjust ``server_args`` (and ``mem_fraction_static``) for this model.

        Selects or validates the attention backend, then applies the follow-up
        tweaks for FP8 KV cache, double sparsity, multimodal chunked prefill,
        chunked prefix cache, and the aiter backend. ``self.server_args`` is
        modified in place.

        Raises:
            ValueError: If the requested backend is not accepted for MLA, or if
                double sparsity is enabled without ``ds_heavy_channel_type``.
        """
        args = self.server_args

        # intel_amx requested on a CPU without AMX support: fall back.
        amx_requested_on_cpu = (
            args.attention_backend == "intel_amx" and args.device == "cpu"
        )
        if amx_requested_on_cpu and not _is_cpu_amx_available:
            logger.info(
                "The current platform does not support Intel AMX, will fallback to torch_native backend."
            )
            args.attention_backend = "torch_native"

        if args.attention_backend is None:
            # Auto select the attention backend.
            #
            # MHA models: fa3 when the hopper / spec-decode / architecture checks
            # all pass; otherwise aiter on HIP, ascend on NPU, else flashinfer if
            # available, else triton.
            # MLA models: fa3 on hopper, flashinfer when sm100 is supported,
            # aiter or triton on HIP, ascend on NPU, else triton.
            if self.use_mla_backend:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    chosen = "fa3"
                elif is_sm100_supported():
                    chosen = "flashinfer"
                elif _is_hip:
                    # TODO current aiter only support head number 16 or 128 head number
                    kv_heads = self.model_config.get_num_kv_heads(self.tp_size)
                    aiter_usable = (
                        kv_heads in (128, 16) and self.spec_algorithm.is_none()
                    )
                    chosen = "aiter" if aiter_usable else "triton"
                elif _is_npu:
                    chosen = "ascend"
                else:
                    chosen = "triton"
            else:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    chosen = "fa3"
                elif _is_hip:
                    chosen = "aiter"
                elif _is_npu:
                    chosen = "ascend"
                elif is_flashinfer_available():
                    chosen = "flashinfer"
                else:
                    chosen = "triton"

            args.attention_backend = chosen
            logger.info(
                f"Attention backend not set. Use {args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            # A backend was requested explicitly: check it is accepted for MLA.
            if args.device == "cpu":
                if args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization not supported on CPU except for intel_amx backend."
                    )
            elif args.attention_backend in (
                "aiter",
                "flashinfer",
                "fa3",
                "triton",
                "flashmla",
                "cutlass_mla",
                "ascend",
            ):
                logger.info(
                    f"MLA optimization is turned on. Use {args.attention_backend} backend."
                )
            else:
                raise ValueError(
                    f"Invalid attention backend for MLA: {args.attention_backend}"
                )

        # fa3 combined with an fp8_e5m2 KV cache is switched to triton.
        if args.attention_backend == "fa3" and args.kv_cache_dtype == "fp8_e5m2":
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            args.attention_backend = "triton"

        if args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            args.attention_backend = "triton"
            args.disable_cuda_graph = True
            heavy_channel_type = args.ds_heavy_channel_type
            if heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(heavy_channel_type)

        # Multimodal models without chunked-prefill support get it disabled.
        if self.is_multimodal and not self.is_multimodal_chunked_prefill_supported:
            args.chunked_prefill_size = -1
            logger.info(
                f"Automatically turn of --chunked-prefill-size as it is not supported for "
                f"{self.model_config.hf_config.model_type}"
            )

        # Chunked prefix cache is kept only for MLA with page_size <= 1.
        if self.use_mla_backend:
            if self.page_size > 1:
                logger.info("Disable chunked prefix cache when page size > 1.")
                args.disable_chunked_prefix_cache = True
        else:
            args.disable_chunked_prefix_cache = True

        if not args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        # aiter with a context length above 8192: shrink the static memory fraction.
        if args.attention_backend == "aiter" and self.model_config.context_len > 8192:
            self.mem_fraction_static *= 0.85

483
    def init_torch_distributed(self):
        """Select the device, initialize the distributed groups, and check memory.

        The distributed environment, model-parallel groups and DP attention are
        only initialized when this runner is not a draft worker.

        Returns:
            The value of ``get_available_gpu_memory`` queried with
            ``distributed=`` set when the world size is greater than one.

        Raises:
            ValueError: If ``self.device`` has no known communication backend
                (non-draft workers only), or if the memory across TP ranks is
                unbalanced and ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not
                set.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            # Log the placement context before re-raising to aid debugging.
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Communication backend for each supported device type. An unknown
        # device yields None instead of leaving the name unbound.
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
            "npu": "hccl",
        }.get(self.device)

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            if backend is None:
                # Previously this surfaced as an UnboundLocalError at the
                # init_distributed_environment call below.
                raise ValueError(
                    f"Unsupported device for distributed initialization: {self.device}"
                )

            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            imbalance_msg = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            # The env var downgrades the failure to a warning.
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(imbalance_msg)
            else:
                raise ValueError(imbalance_msg)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
581

Lianmin Zheng's avatar
Lianmin Zheng committed
582
    def load_model(self):
        """Load the model weights and record load statistics on the runner.

        Sets ``self.load_config``, ``self.model``, ``self.sliding_window_size``,
        ``self.dtype`` and ``self.weight_load_mem_usage``. On CUDA devices with
        compute capability major version below 8, the dtype is switched to
        float16. After loading, all TP ranks synchronize on a monitored barrier.

        Raises:
            RuntimeError: If the CUDA device is below sm75, or if FP8 KV-cache
                scaling factors were provided but the model has no callable
                ``load_kv_cache_scales``.
            ValueError: If the post-load barrier fails, i.e. other ranks did not
                finish loading within ``UNBALANCED_MODEL_LOADING_TIMEOUT_S``
                seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            capability = tuple(torch.cuda.get_device_capability())
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare (major, minor) as a pair; checking the minor version
                # alone only makes sense for major version 7.
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Undo the patches even when loading fails, so they never stay
            # applied after this method exits.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    # Exceptions do not apply %-formatting to their arguments,
                    # so the model class is interpolated explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            # Routed through the module logger (was a bare print) so it follows
            # the same rank filtering as the other messages here.
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

687
    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        """Delegate an expert-location change to ``self.expert_location_updater``.

        Args:
            new_expert_location_metadata: The expert location metadata to
                switch to.
            update_layer_ids: Ids of the layers that should be updated.
        """
        updater = self.expert_location_updater
        routed_weights = self.model.routed_experts_weights_of_layer
        updater.update(
            routed_weights,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

700
701
702
703
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Path of the checkpoint to load the new weights from.
            load_format: Load format used to build the ``LoadConfig``.

        Returns:
            A ``(success, message)`` tuple. On failure the model config keeps
            pointing at the previous ``model_path``; if the failure happened
            while loading, the previous weights are reloaded first.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            logger.error(message)
            return False, message

        def get_weight_iter(config):
            return loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        # Remember the current path so that every failure path can restore it.
        # Without this, the rollback below would re-read the *new* checkpoint.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                logger.error(message)
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                logger.error(message)
                del weights_iter
                gc.collect()
                # NOTE(review): the rollback reuses the loader built for the
                # new `load_format`; confirm that is acceptable when the
                # original checkpoint was loaded with a different format.
                self.model_config.model_path = original_model_path
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
754

755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Returns:
            A ``(success, message)`` tuple; failures while creating the group
            are logged and reported through the tuple instead of raising.
        """
        default_group_ready = torch.distributed.is_initialized()
        assert default_group_ready, "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        # This engine's rank inside the update group is shifted by rank_offset.
        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            process_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            self._model_update_group[group_name] = process_group
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message
        else:
            return True, "Succeeded to initialize custom process group."

800
    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameters in the model weights online
        through the `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated; each entry
                is either a ``torch.dtype`` or the name of one (looked up as
                an attribute of ``torch``).
            shapes: the shapes of the parameters to be updated.
            group_name: the key of the process group in `_model_update_group`
                that the weights are broadcast over.

        Returns:
            A ``(success, message)`` tuple.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        # `zip` would silently drop the unmatched tail, leaving some parameters
        # stale while still reporting success, so reject mismatched inputs.
        names, dtypes, shapes = list(names), list(dtypes), list(shapes)
        if not (len(names) == len(dtypes) == len(shapes)):
            error_msg = (
                f"Failed to update parameter online: names, dtypes and shapes "
                f"must have the same length, but got {len(names)}, "
                f"{len(dtypes)} and {len(shapes)}. No weights were updated."
            )
            logger.error(error_msg)
            return False, error_msg

        try:
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                # Launch all broadcasts asynchronously and wait for them below.
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

848
849
850
851
852
853
854
855
856
857
858
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Update model weights from in-memory tensors.

        Args:
            named_tensors: ``(name, tensor)`` pairs; every tensor is first
                passed through ``_unwrap_tensor`` with this runner's TP rank.
            load_format: ``None`` to use the model's own ``load_weights``,
                ``"direct"`` to use ``_model_load_weights_direct``, or one of
                the entries of ``server_args.custom_weight_loader``, which is
                imported with ``dynamic_import`` and called as
                ``loader(model, named_tensors)``.

        Returns:
            ``(True, "Success")`` when the weights were handed to a loader.

        Raises:
            NotImplementedError: If ``load_format`` is not recognised.
        """
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        # Check the default (None) first so it is never used in a membership
        # test against the custom loader list, which may be unset.
        if load_format is None:
            self.model.load_weights(named_tensors)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in (self.server_args.custom_weight_loader or ()):
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
867

868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Fetch a parameter's weights by name, similar to Hugging Face's `get_parameter`.

        This is meant for unit tests and is not optimized; use torch.save and
        torch.load when performance matters.

        Returns:
            Whatever the model's own ``get_weights_by_name`` returns, or
            ``None`` (after logging the error) if that call raises.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None
        return weights

885
886
887
888
889
890
891
    def init_lora_manager(self):
        """Create ``self.lora_manager`` and load the adapters from ``server_args.lora_paths``.

        Raises:
            RuntimeError: If the manager reports that loading the adapters
                did not succeed.
        """
        args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=args.max_lora_rank,
            target_modules=args.lora_target_modules,
        )

        result = self.lora_manager.load_lora_adapters(args.lora_paths)
        if not result.success:
            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")

        logger.info(
            f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
        )

    def load_lora_adapter(self, lora_name: str, lora_path: str):
        """Load a new lora adapter from disk or huggingface.

        Logs the available device memory before and after the load and
        returns the result object produced by the LoRA manager.
        """
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
            f"avail mem={mem_before:.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)

        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
            f"avail mem={mem_after:.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_name: str):
        """Unload a lora adapter that was previously loaded during initialization or dynamic loading.

        Logs the available device memory before and after the unload and
        returns the result object produced by the LoRA manager.
        """
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"LoRA adapter unloading starts: name={lora_name}. "
            f"avail mem={mem_before:.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_name)

        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"LoRA adapter unloading completes: name={lora_name}. "
            f"avail mem={mem_after:.2f} GB"
        )

        return result
939

940
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many KV-cache tokens fit in the memory left for the cache.

        The per-token cell size is derived from the model config and the KV
        cache dtype; the budget is the currently available memory minus the
        ``(1 - mem_fraction_static)`` share of ``total_gpu_memory``.

        Args:
            total_gpu_memory: Total device memory; the ``1 << 30`` scaling
                below suggests it is expressed in GB — TODO confirm.

        Returns:
            The maximum number of tokens as an ``int``.
        """
        world_group = get_world_group()
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=world_group.world_size > 1,
            cpu_group=world_group.cpu_group,
        )

        num_layers = self.num_effective_layers
        if self.is_draft_worker:
            # Draft workers use the config's own layer count when it has one.
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                num_layers,
            )

        config = self.model_config
        if self.use_mla_backend:
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            elems_per_layer = config.kv_lora_rank + config.qk_rope_head_dim
        else:
            num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
            elems_per_layer = num_kv_heads * config.head_dim * 2

        element_size = torch._utils._element_size(self.kv_cache_dtype)
        cell_size = elems_per_layer * num_layers * element_size

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // cell_size)

tarinkk's avatar
tarinkk committed
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
    def set_num_token_hybrid(self):
        """Split the token budget between full-attention and sliding-window layers.

        Sets ``self.full_max_total_num_tokens`` and
        ``self.swa_max_total_num_tokens`` and overwrites
        ``self.max_total_num_tokens`` with the full-attention value.

        For ``Llama4ForConditionalGeneration`` the split is derived from
        ``attention_chunk_size / context_len`` and rounded down to a multiple
        of the page size. For every other architecture the layers are
        classified by their ``sliding_window_size`` and the split follows
        ``server_args.swa_full_tokens_ratio``. If the layer list cannot be
        found on the model, ``self.is_hybrid`` is set to ``False`` and nothing
        else is changed.
        """
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            # Round both budgets down to whole pages.
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            # Only a missing attribute means "this model has no such layer
            # list"; any other exception is a real error and must propagate.
            try:
                layers = self.model.model.layers
            except AttributeError:
                try:
                    layers = self.model.language_model.model.layers
                except AttributeError:
                    self.is_hybrid = False
                    return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
            # - Find total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

1056
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and create the request/token memory pools of this runner.

        Resolves ``self.kv_cache_dtype`` from ``server_args.kv_cache_dtype``,
        determines ``self.max_total_num_tokens`` and then creates
        ``self.req_to_token_pool``, ``self.token_to_kv_pool`` and
        ``self.token_to_kv_pool_allocator`` (the first and the last only when
        they are still ``None``).

        Args:
            total_gpu_memory: Total device memory, forwarded to
                ``profile_max_num_token``.
            max_num_reqs: Maximum number of requests; derived from the token
                budget and the context length when ``None``.
            max_total_tokens: Optional upper bound for the token budget.

        Raises:
            ValueError: If ``server_args.kv_cache_dtype`` is not supported.
            RuntimeError: If the resulting token budget is not positive.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            # Scale with the token budget, clamped to the range [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger like the rest of this file, not the
                # root logger.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round the budget down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        # create token size for hybrid cache
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # subscribe memory for pre-allocated requests
                # if max_num_reqs <= 32, we pre-allocate 2x requests
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
            self.token_to_kv_pool = AscendTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                if self.is_hybrid:
                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
                        self.full_max_total_num_tokens,
                        self.swa_max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
                else:
                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
            else:
                if _is_npu:
                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
                else:
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
        else:
            # A pre-existing allocator is only expected on the draft worker.
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1297

Lianmin Zheng's avatar
Lianmin Zheng committed
1298
1299
1300
1301
1302
1303
1304
1305
1306
    def init_cublas(self):
        """Run a tiny fp16 matmul on CUDA so that cuBLAS is initialized up front.

        Without this warm-up, errors are raised later on.
        """
        lhs, rhs = (
            torch.ones((16, 16), dtype=torch.float16, device="cuda") for _ in range(2)
        )
        return lhs @ rhs

1307
1308
    def init_attention_backend(self):
        """Init attention kernel backend.

        With two-batch overlap enabled on a non-draft worker, the backend
        factory is handed to ``TboAttnBackend.init_new``; otherwise the
        factory is called directly.
        """
        use_tbo = self.server_args.enable_two_batch_overlap and not self.is_draft_worker
        if not use_tbo:
            self.attn_backend = self._get_attention_backend()
            return
        self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)

    # TODO unify with 6338
    def _get_attention_backend(self):
1316
        if self.server_args.attention_backend == "flashinfer":
1317
1318
1319
1320
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )
1321

1322
1323
1324
                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
1325
                return FlashInferAttnBackend(self)
1326
1327
1328
1329
1330
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

1331
                return FlashInferMLAAttnBackend(self)
1332
1333
1334
        elif self.server_args.attention_backend == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

1335
            return AiterAttnBackend(self)
1336
1337
1338
1339
        elif self.server_args.attention_backend == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
1340
1341
1342
1343
1344
1345
        elif self.server_args.attention_backend == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
1346
1347
1348
1349
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

1350
                return DoubleSparseAttnBackend(self)
1351
            else:
1352
1353
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

1354
                return TritonAttnBackend(self)
1355
        elif self.server_args.attention_backend == "torch_native":
1356
1357
1358
1359
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

1360
            return TorchNativeAttnBackend(self)
lukec's avatar
lukec committed
1361
1362
1363
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

1364
            return FlashMLABackend(self)
1365
        elif self.server_args.attention_backend == "fa3":
1366
1367
1368
1369
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
1370
1371
1372
1373
1374
1375
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

1376
            return FlashAttentionBackend(self)
1377
1378
1379
1380
1381
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

1382
            return CutlassMLABackend(self)
1383
1384
1385
1386
1387
1388
1389
        elif self.server_args.attention_backend == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
1390
1391
1392
1393
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
1394

Shuo Yang's avatar
Shuo Yang committed
1395
1396
1397
1398
1399
1400
1401
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer channel tensors used by double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path`` and, for
        every layer index in ``[start_layer, end_layer)``, looks up the entry
        ``model.layers.<i>.self_attn.<selected_channel>_proj``. The first
        ``server_args.ds_heavy_channel_num`` columns of each entry are kept and
        moved to the CUDA device; the results are stored in order in
        ``self.sorted_channels``.

        Args:
            selected_channel: Projection name without the ``_proj`` suffix,
                used to build the config key.

        Raises:
            KeyError: If the config file has no entry for one of the layers.
        """
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

1412
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        Always resets ``self.cuda_graph_runner`` to ``None`` and
        ``self.cuda_graph_mem_usage`` to ``0`` first. Capture is skipped for
        non-generation models and when ``server_args.disable_cuda_graph`` is
        set. Otherwise a ``CudaGraphRunner`` is built and the drop in available
        GPU memory across the capture is recorded as the graph memory usage.
        """
        self.cuda_graph_runner = None
        self.cuda_graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        # Usage is measured as the difference in available memory, in GB.
        self.cuda_graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )
1436

1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
    def init_threads_binding(self):
        """Choose the CPU ids this TP rank binds its OMP threads to.

        The choice is stored in ``self.local_omp_cpuid``.

        If ``SGLANG_CPU_OMP_THREADS_BIND`` is set to anything other than
        ``"all"``, it is split on ``|`` and the entry at index ``tp_rank`` is
        used. Otherwise (unset or ``"all"``) the rank takes the CPU ids of the
        NUMA node with the same index as its ``tp_rank``, which requires
        ``tp_size`` to be no larger than the number of NUMA nodes.
        """
        binding_spec = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")

        # Explicit per-rank binding: nothing to detect, just pick our slot.
        if binding_spec != "all":
            self.local_omp_cpuid = binding_spec.split("|")[self.tp_rank]
            return

        node_cpu_ids = get_cpu_ids_by_node()
        n_numa_node = len(node_cpu_ids)

        assert self.tp_size <= n_numa_node, (
            f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
            f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
            f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
            f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
            f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
            f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
            f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
            f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
        )

        # Fewer ranks than NUMA nodes is allowed, but some nodes stay unused.
        if self.tp_size < n_numa_node:
            logger.warning(
                f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
            )

        self.local_omp_cpuid = node_cpu_ids[self.tp_rank]

1461
    def apply_torch_tp(self):
        """Apply torch tensor parallelism to ``self.model``.

        Builds a one-dimensional device mesh of ``tp_size`` devices on
        ``self.device`` and passes it, with the model, to
        ``sglang.srt.model_parallel.tensor_parallel``.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

1468
1469
1470
    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model forward pass for a decode batch.

        The attention backend's forward metadata is initialized for this batch
        before the model is called.

        Args:
            forward_batch: The batch to run.
            pp_proxy_tensors: Forwarded to the model as a keyword argument only
                when ``self.support_pp`` is true.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )

1480
    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        """Run the model forward pass for an extend batch.

        Args:
            forward_batch: The batch to run.
            skip_attn_backend_init: When true, the attention backend's forward
                metadata is not re-initialized for this batch.
            pp_proxy_tensors: Forwarded to the model as a keyword argument only
                when ``self.support_pp`` is true.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            # Caller-supplied embeddings are cast to bfloat16 before the forward.
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1502

1503
1504
1505
1506
1507
1508
    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model forward pass for an idle batch.

        Unlike the decode and extend paths, this does not initialize the
        attention backend's forward metadata.

        Args:
            forward_batch: The batch to run.
            pp_proxy_tensors: Forwarded to the model as a keyword argument only
                when ``self.support_pp`` is true.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        """Run one slice of a split prefill and advance the batch's split index.

        Args:
            forward_batch: The batch to run. Its ``split_index`` is read as the
                start of the slice and overwritten with the slice end.
            reinit_attn_backend: Force re-initialization of the attention
                backend's forward metadata. This also happens whenever
                ``split_index`` is 0.
            forward_count: How far to advance ``split_index`` in this call. The
                end is capped at ``model_config.num_hidden_layers``.

        Returns:
            Whatever ``self.model.forward_split_prefill`` returns.
        """
        if reinit_attn_backend or forward_batch.split_index == 0:
            self.attn_backend.init_forward_metadata(forward_batch)

        split_start = forward_batch.split_index
        # Never advance past the model's hidden-layer count.
        split_end = min(
            split_start + forward_count, self.model_config.num_hidden_layers
        )
        output = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (split_start, split_end),
        )
        forward_batch.split_index = split_end
        return output

1537
    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Run one forward pass and notify the expert-distribution hooks.

        Increments ``self.forward_pass_id``, runs ``_forward_raw`` inside the
        global expert distribution recorder's ``with_forward_pass`` context,
        and afterwards calls ``eplb_manager.on_forward_pass_end()`` when an
        EPLB manager is configured.

        Args:
            forward_batch: The batch to run.
            skip_attn_backend_init: Passed through to ``_forward_raw``.
            pp_proxy_tensors: Passed through to ``_forward_raw``.
            reinit_attn_backend: Passed through to ``_forward_raw``.
            split_forward_count: Passed through to ``_forward_raw``.

        Returns:
            The ``(output, can_run_cuda_graph)`` tuple returned by
            ``_forward_raw``.
        """
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

1564
1565
1566
1567
1568
    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Dispatch the batch to the matching forward routine.

        A cuda graph replay is used when the forward mode reports
        ``is_cuda_graph()``, a cuda graph runner exists, and that runner says it
        can run the batch. Otherwise the batch goes to ``forward_decode``,
        ``forward_extend``, ``forward_split_prefill`` or ``forward_idle``
        according to its forward mode, checked in that order.

        Args:
            forward_batch: The batch to run.
            skip_attn_backend_init: Forwarded to the cuda graph replay and to
                ``forward_extend``.
            pp_proxy_tensors: Forwarded to every path except split prefill.
            reinit_attn_backend: Forwarded to ``forward_split_prefill``.
            split_forward_count: Forwarded to ``forward_split_prefill`` as
                ``forward_count``.

        Returns:
            ``(ret, can_run_cuda_graph)``: the selected routine's result and
            whether the cuda graph path was taken.

        Raises:
            ValueError: If the forward mode matches none of the known modes.
        """
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            ret = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return ret, can_run_cuda_graph

1604
1605
1606
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Make sure the vocab mask is ready, then apply the logit bias.

        Args:
            logits_output: Model output whose ``next_token_logits`` is passed
                to ``sampling_info.apply_logits_bias``.
            sampling_info: Sampling state for the batch.
        """
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward. A tuple of
                outputs (duplex models with multiple output streams) is sampled
                element by element and the results are stacked along a new
                last dimension.
            forward_batch: The forward batch that generates logits_output

        Returns:
            A tensor of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                dim=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

Yineng Zhang's avatar
Yineng Zhang committed
1651
1652
1653
1654
    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases.

        Returns:
            True when ``hf_text_config.rope_scaling`` contains the key
            ``"mrope_section"``; False when the attribute is missing, is
            ``None``, or lacks that key.
        """
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        # The attribute can be present but explicitly set to None.
        if rope_scaling is None:
            return False
        return "mrope_section" in rope_scaling
1660

1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
    def save_remote_model(self, url: str):
        """Save the loaded model to a remote location via ``RemoteModelLoader``.

        Args:
            url: Destination passed to ``RemoteModelLoader.save_model`` along
                with the model and its configured ``model_path``.
        """
        # Imported lazily, only when a save is actually requested.
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(
            self.model,
            self.model_config.model_path,
            url,
        )

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the loaded model in sharded form via ``ShardedStateLoader``.

        Args:
            path: Destination passed to ``ShardedStateLoader.save_model``.
            pattern: Optional pattern forwarded unchanged to the loader.
            max_size: Optional size limit forwarded unchanged to the loader.
        """
        # Imported lazily, only when a save is actually requested.
        from sglang.srt.model_loader.loader import ShardedStateLoader

        message = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(message)
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

1677
1678
1679
1680
1681
1682
1683
1684
1685

def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return ``tensor`` on the current CUDA device, deserializing it if needed.

    Args:
        tensor: Either an object with a ``.to(device)`` method or a
            ``LocalSerializedTensor`` wrapper.
        tp_rank: Rank used to select the entry of a ``LocalSerializedTensor``.

    Returns:
        The (possibly deserialized) tensor moved to
        ``torch.cuda.current_device()``.
    """
    if isinstance(tensor, LocalSerializedTensor):
        # Patch torch reductions before deserializing this rank's entry.
        monkey_patch_torch_reductions()
        tensor = tensor.get(tp_rank)
    return tensor.to(torch.cuda.current_device())
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699


@dataclass
class LocalSerializedTensor:
    """Per-rank serialized handles for a torch.Tensor.

    Each entry is produced by MultiprocessingSerializer, which serializes only
    a pointer to the tensor and not its data. Entry ``i`` belongs to the GPU of
    rank ``i``.
    """

    # One serialized payload per rank, indexed by rank.
    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the entry stored for ``rank``."""
        serialized = self.values[rank]
        return MultiprocessingSerializer.deserialize(serialized)