model_runner.py 45.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import datetime
17
import gc
Shuo Yang's avatar
Shuo Yang committed
18
import json
19
import logging
20
import os
21
import time
22
23
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
24
25

import torch
26
import torch.distributed as dist
27
28
29
30
31

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
32
33
34
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
35
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
36
)
37
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
38
39
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
40
    get_attention_tp_size,
41
42
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
43
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
45
46
47
48
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
49
from sglang.srt.layers.sampler import Sampler
50
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
51
from sglang.srt.lora.lora_manager import LoRAManager
52
from sglang.srt.managers.schedule_batch import global_server_args_dict
53
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
54
    DoubleSparseTokenToKVPool,
55
56
57
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
58
    TokenToKVPoolAllocator,
59
)
Lianmin Zheng's avatar
Lianmin Zheng committed
60
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
61
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
62
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
63
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
64
65
66
67
68
69
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
70
from sglang.srt.model_loader.weight_utils import default_weight_loader
71
from sglang.srt.patch_torch import monkey_patch_torch_reductions
72
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
73
from sglang.srt.server_args import ServerArgs
74
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
75
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
76
from sglang.srt.utils import (
77
    MultiprocessingSerializer,
78
    enable_show_time_cost,
79
    get_available_gpu_memory,
80
    get_bool_env_var,
81
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
82
    is_cuda,
83
    is_fa3_default_architecture,
84
    is_flashinfer_available,
HAI's avatar
HAI committed
85
    is_hip,
86
    is_hopper_with_cuda_12_3,
87
    is_no_spec_infer_or_topk_one,
88
    monkey_patch_p2p_access_check,
89
    monkey_patch_vllm_gguf_config,
90
    set_cpu_offload_max_bytes,
91
    set_cuda_arch,
92
)
93

Lianmin Zheng's avatar
Lianmin Zheng committed
94
# Use a small KV cache pool size for tests in CI (None when the variable is unset).
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Detect straggler ranks in model loading: how many seconds the post-load
# barrier waits for every TP rank before giving up.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

102

Lianmin Zheng's avatar
Lianmin Zheng committed
103
class ModelRunner:
104
105
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
106
107
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Set up the runner: record arguments, adjust server args, init distributed state, then load.

        Args:
            model_config: Configuration of the model to run.
            mem_fraction_static: Static memory fraction; may be reduced for
                multimodal models by ``model_specific_adjustment``.
            gpu_id: Device index this runner binds to.
            tp_rank: Tensor-parallel rank of this process.
            tp_size: Tensor-parallel world size.
            nccl_port: Port used for the local distributed init method.
            server_args: Server arguments; several fields are mutated in place.
            is_draft_worker: Whether this runner hosts a speculative draft model.
            req_to_token_pool: Optional pre-built request-to-token pool.
            token_to_kv_pool_allocator: Optional pre-built KV pool allocator.
        """
        # Record constructor arguments and values derived from them.
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        # Must run before the global dict is filled: it may rewrite
        # server_args.attention_backend and related flags.
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Publish selected server args process-wide. Each key mirrors the
        # ServerArgs attribute of the same name; use_mla_backend is derived.
        forwarded_keys = (
            "attention_backend",
            "sampling_backend",
            "triton_attention_reduce_in_fp32",
            "torchao_config",
            "enable_nan_detection",
            "enable_dp_attention",
            "enable_ep_moe",
            "enable_deepep_moe",
            "deepep_mode",
            "device",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
            "disable_radix_cache",
            "flashinfer_mla_disable_ragged",
            "moe_dense_tp_size",
            "debug_tensor_dump_output_folder",
            "debug_tensor_dump_inject",
            "n_share_experts_fusion",
            "disable_chunked_prefix_cache",
        )
        published = {key: getattr(server_args, key) for key in forwarded_keys}
        published["use_mla_backend"] = self.use_mla_backend
        global_server_args_dict.update(published)

        # CPU offload budget, converted from GB to bytes.
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Available memory is measured as part of distributed init, before the model loads.
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # Load the model and build memory pools, attention backends and CUDA graphs.
        self.initialize(min_per_gpu_memory)

    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build everything needed to serve forward passes.

        Args:
            min_per_gpu_memory: Minimum available device memory across TP ranks,
                as returned by ``init_torch_distributed``; forwarded to
                ``init_memory_pool``.
        """
        args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=args.enable_memory_saver
        )

        # Sampler first, then the model weights.
        self.sampler = Sampler()
        self.load_model()

        # Layered loading may already have applied torchao; don't apply it twice.
        already_quantized = getattr(self.model, "torchao_applied", False)
        if not already_quantized:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Torch-native tensor parallelism, only for models that opt in.
        model_supports_tp = getattr(self.model, "supports_torch_tp", False)
        if model_supports_tp and self.tp_size > 1:
            self.apply_torch_tp()

        if args.lora_paths is not None:
            self.init_lora_manager()

        # Memory pools must exist before attention backends are created.
        self.init_memory_pool(
            min_per_gpu_memory,
            args.max_running_requests,
            args.max_total_tokens,
        )

        if self.device != "cuda":
            # CUDA graphs are only available on CUDA devices.
            self.cuda_graph_runner = None
            self.init_attention_backend()
        else:
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

231
232
233
    def model_specific_adjustment(self):
        """Tune server args and runner settings for the model being served.

        Mutates ``self.server_args`` in place:
        - picks or validates ``attention_backend``;
        - may force triton and disable CUDA graphs for double sparsity;
        - disables chunked prefill for multimodal models;
        - sets ``disable_chunked_prefix_cache``.

        It also shrinks ``self.mem_fraction_static`` by 10% for multimodal models.

        Raises:
            ValueError: if an MLA model runs on CPU, uses an unsupported attention
                backend, or double sparsity is enabled without a heavy channel type.
        """
        args = self.server_args

        if args.attention_backend is None:
            # Auto-select the fastest attention backend:
            #   1. MHA models (e.g. Llama, Qwen): FA3 on Hopper with CUDA >= 12.3 when
            #      `is_no_spec_infer_or_topk_one` holds and the architecture is an
            #      FA3-default one; otherwise flashinfer if available, else triton.
            #   2. MLA models: FA3 on Hopper with CUDA >= 12.3, otherwise triton.
            if self.use_mla_backend:
                chosen = "fa3" if is_hopper_with_cuda_12_3() else "triton"
            elif (
                is_hopper_with_cuda_12_3()
                and is_no_spec_infer_or_topk_one(args)
                and is_fa3_default_architecture(self.model_config.hf_config)
            ):
                chosen = "fa3"
            elif is_flashinfer_available():
                chosen = "flashinfer"
            else:
                chosen = "triton"
            args.attention_backend = chosen
            logger.info(
                f"Attention backend not set. Use {args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            # A user-chosen backend for an MLA model must be one that supports MLA.
            if args.device == "cpu":
                raise ValueError("MLA optimization not supported on CPU.")
            mla_backends = ("flashinfer", "fa3", "triton", "flashmla", "cutlass_mla")
            if args.attention_backend not in mla_backends:
                raise ValueError(
                    f"Invalid attention backend for MLA: {args.attention_backend}"
                )
            logger.info(
                f"MLA optimization is turned on. Use {args.attention_backend} backend."
            )

        # FA3 cannot handle an e5m2 FP8 KV cache; fall back to triton.
        if args.attention_backend == "fa3" and args.kv_cache_dtype == "fp8_e5m2":
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            args.attention_backend = "triton"

        if args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            args.attention_backend = "triton"
            args.disable_cuda_graph = True
            if args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )
            logger.info(
                "Automatically turn off --chunked-prefill-size for multimodal model."
            )
            args.chunked_prefill_size = -1

        # The chunked prefix cache is only used with the MLA backend and page size 1.
        if not self.use_mla_backend:
            args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            args.disable_chunked_prefix_cache = True

        if not args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

328
    def init_torch_distributed(self):
        """Bind this process to its device and initialize torch.distributed state.

        The distributed environment, model-parallel groups and DP attention are
        only initialized on the target (non-draft) worker. Afterwards the TP
        group handles are cached on ``self``.

        Returns:
            The minimum available device memory across TP ranks, as reported by
            ``get_available_gpu_memory(..., distributed=self.tp_size > 1)``.

        Raises:
            ValueError: if the device type has no known collective backend and the
                distributed environment must be initialized here. Also raised if
                available memory is unbalanced across TP ranks and
                ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not set.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Collective backend per device type. Unknown devices map to None, which
        # is only an error when we actually initialize the environment below.
        # (Previously this left `backend` unbound and failed with UnboundLocalError.)
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
        }.get(self.device)

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            if backend is None:
                raise ValueError(
                    f"Unsupported device for torch distributed: {self.device}"
                )
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism: a rank with much less free memory
        # than this one usually means the device is shared with another process.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            message = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(message)
            else:
                raise ValueError(message)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
401

Lianmin Zheng's avatar
Lianmin Zheng committed
402
    def load_model(self):
        """Load the model weights onto ``self.device`` and record load metadata.

        Sets ``self.load_config``, ``self.model``, ``self.sliding_window_size`` and
        ``self.dtype``. On CUDA devices below sm80 the dtype is forced to float16
        because bfloat16 is unsupported there. Devices below sm75 are rejected.

        Raises:
            RuntimeError: if the CUDA device is below sm75, or if FP8 KV-cache
                scaling factors were provided for a model that cannot load them.
            ValueError: if other TP ranks do not finish loading within
                ``UNBALANCED_MODEL_LOADING_TIMEOUT_S`` seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            capability = tuple(torch.cuda.get_device_capability())
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare (major, minor) together: checking only the minor
                # version would wrongly accept e.g. sm35/sm37.
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region():
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Always undo the global patches, even if loading fails.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            scales_path = self.server_args.quantization_param_path
            if scales_path is None:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )
            elif callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(scales_path)
                logger.info("Loaded KV cache scaling factors from %s", scales_path)
            else:
                # Exceptions do not %-format their args like logging does, so
                # build the message explicitly.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    f"model {self.model.__class__} does not support loading scaling factors."
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

497
498
499
500
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Path (or model id) of the new weights.
            load_format: Load format for the new weights. Only formats served by
                ``DefaultModelLoader`` are supported.

        Returns:
            ``(success, message)``. If loading the new weights fails part-way,
            the weights at the previously configured model path are reloaded
            and ``(False, message)`` is returned.

        The new ``model_path``/``load_format`` are recorded on the configs only
        after a successful update, so failed attempts leave them unchanged.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember where the current weights came from so a failed update can
        # roll back to them (previously the config was overwritten up front and
        # the "rollback" reloaded the failed new weights instead).
        original_model_path = self.model_config.model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(path):
            # Build a weights iterator for `path` at the configured revision.
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    path,
                    revision=self.model_config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            # Load weights into `model`, then let quantized modules post-process
            # them on the target device.
            model.load_weights(weights_iter)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(model_path)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                weights_iter = get_weight_iter(original_model_path)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.model_config.model_path = model_path
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
562

563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Args:
            master_address: Host of the rendezvous TCP store.
            master_port: Port of the rendezvous TCP store.
            rank_offset: Added to this runner's TP rank to get its rank in the group.
            world_size: Total number of ranks in the update group.
            group_name: Non-empty name of the process group.
            backend: torch.distributed backend for the group (default ``"nccl"``).

        Returns:
            ``(True, message)`` on success, ``(False, error_message)`` otherwise.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            # `device_ids` selects the local device used by the NCCL barrier, so it
            # must be this process's device index (`self.gpu_id`). The rank inside
            # the custom group only matches it by coincidence. Other backends get a
            # plain barrier.
            if backend == "nccl":
                dist.barrier(group=self._model_update_group, device_ids=[self.gpu_id])
            else:
                dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated, either a
                ``torch.dtype`` or the name of one (e.g. ``"bfloat16"``).
            shape: the shape of the parameter to be updated.

        Returns:
            ``(True, message)`` on success, ``(False, error_message)`` if the
            broadcast or load fails. A failure may leave weights partially updated.

        Raises:
            AssertionError: if the model update group has not been initialized.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        # The attribute only exists once init_weights_update_group has succeeded.
        # getattr makes the intended assertion fire instead of an AttributeError.
        assert (
            getattr(self, "_model_update_group", None) is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            # Rank 0 of the update group (the trainer) is the broadcast source.
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load ``(name, tensor)`` pairs into the model.

        Each tensor is first passed through ``_unwrap_tensor`` for this TP rank.
        Presumably this materializes ``LocalSerializedTensor`` entries; confirm
        against its definition.

        Args:
            named_tensors: Parameter names paired with tensors or serialized tensors.
            load_format: ``None`` to use ``model.load_weights``, or ``"direct"``
                to use ``_model_load_weights_direct``.

        Returns:
            ``(True, "Success")``.

        Raises:
            NotImplementedError: for any other ``load_format``.
        """
        unwrapped = []
        for param_name, payload in named_tensors:
            unwrapped.append((param_name, _unwrap_tensor(payload, tp_rank=self.tp_rank)))

        if load_format is None:
            self.model.load_weights(unwrapped)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
658

659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

676
677
678
679
680
681
682
683
    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
684
            lora_backend=self.server_args.lora_backend,
685
686
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
687
688
689
        )
        logger.info("LoRA manager ready.")

690
    def profile_max_num_token(self, total_gpu_memory: int):
691
        available_gpu_memory = get_available_gpu_memory(
Zhang, Liangang's avatar
Zhang, Liangang committed
692
            self.device, self.gpu_id, distributed=self.tp_size > 1
693
        )
694
        if self.use_mla_backend:
695
696
697
698
699
            num_layers = (
                self.model_config.num_hidden_layers
                if not self.is_draft_worker
                else self.model_config.hf_config.num_nextn_predict_layers
            )
700
701
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
702
                * num_layers
703
                * torch._utils._element_size(self.kv_cache_dtype)
704
705
706
            )
        else:
            cell_size = (
707
                self.model_config.get_num_kv_heads(get_attention_tp_size())
708
709
710
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
711
                * torch._utils._element_size(self.kv_cache_dtype)
712
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
713
714
715
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
716
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
Lianmin Zheng's avatar
Lianmin Zheng committed
717
718
        return max_num_token

719
    def init_memory_pool(
720
721
        self,
        total_gpu_memory: int,
722
723
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
724
    ):
725
726
727
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
728
            if is_hip():  # Using natively supported format
HAI's avatar
HAI committed
729
730
731
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
bjmsong's avatar
bjmsong committed
732
733
734
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
735
736
737
738
739
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

740
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
741
742
743
744
745
746
747
748
749
750
751
752

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

753
754
755
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

756
757
758
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
759
                max_num_reqs = self.server_args.max_num_reqs
760
            else:
761
762
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
763
764
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
765
766
767
768
769
770
771
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
772
773
                    + 100
                )
774
775
776
777
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs
778

779
780
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
781
                logging.warning(
782
783
784
785
786
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
787

788
789
790
791
792
793
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

794
        if self.max_total_num_tokens <= 0:
795
            raise RuntimeError(
796
                "Not enough memory. Please try to increase --mem-fraction-static."
797
            )
798

799
800
801
802
803
804
805
806
807
808
809
        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

810
        if self.use_mla_backend:
811
812
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
Lianmin Zheng's avatar
Lianmin Zheng committed
813
                page_size=self.page_size,
814
                dtype=self.kv_cache_dtype,
815
816
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
817
818
819
820
821
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),
Zhang, Liangang's avatar
Zhang, Liangang committed
822
                device=self.device,
823
                enable_memory_saver=self.server_args.enable_memory_saver,
824
            )
Shuo Yang's avatar
Shuo Yang committed
825
826
827
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
Lianmin Zheng's avatar
Lianmin Zheng committed
828
                page_size=self.page_size,
Shuo Yang's avatar
Shuo Yang committed
829
                dtype=self.kv_cache_dtype,
830
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
Shuo Yang's avatar
Shuo Yang committed
831
832
833
834
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
835
                enable_memory_saver=self.server_args.enable_memory_saver,
Shuo Yang's avatar
Shuo Yang committed
836
            )
837
838
839
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
Lianmin Zheng's avatar
Lianmin Zheng committed
840
                page_size=self.page_size,
841
                dtype=self.kv_cache_dtype,
842
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
843
844
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
Zhang, Liangang's avatar
Zhang, Liangang committed
845
                device=self.device,
846
                enable_memory_saver=self.server_args.enable_memory_saver,
847
            )
848
849

        if self.token_to_kv_pool_allocator is None:
Lianmin Zheng's avatar
Lianmin Zheng committed
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
865
866
867
        else:
            assert self.is_draft_worker

868
        logger.info(
869
            f"Memory pool end. "
Zhang, Liangang's avatar
Zhang, Liangang committed
870
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
871
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
872

Lianmin Zheng's avatar
Lianmin Zheng committed
873
874
875
876
877
878
879
880
881
    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

882
883
    def init_attention_backend(self):
        """Init attention kernel backend."""
884
        if self.server_args.attention_backend == "flashinfer":
885
886
887
888
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )
889

890
891
892
893
894
895
896
897
898
899
                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
900
901
902
903
904
905
906
907
908
909
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
910
911
912
913
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

914
                self.attn_backend = DoubleSparseAttnBackend(self)
915
            else:
916
917
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

918
919
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
920
921
922
923
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

924
            self.attn_backend = TorchNativeAttnBackend(self)
lukec's avatar
lukec committed
925
926
927
928
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
929
930
931
932
933
934
935
936
937
938
        elif self.server_args.attention_backend == "fa3":
            assert torch.cuda.get_device_capability()[0] >= 9, (
                "FlashAttention v3 Backend requires SM>=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
939
940
941
942
943
944
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
945
946
947
948
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
949

Shuo Yang's avatar
Shuo Yang committed
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

967
    def init_cuda_graphs(self):
968
        """Capture cuda graphs."""
969
970
        self.cuda_graph_runner = None

971
972
973
974
        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

975
976
        if self.server_args.disable_cuda_graph:
            return
977

978
        tic = time.time()
979
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
980
        logger.info(
981
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
982
        )
983
        self.cuda_graph_runner = CudaGraphRunner(self)
984
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
985
        logger.info(
986
            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
987
            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
988
        )
989

990
991
992
993
994
995
996
    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

997
    def forward_decode(self, forward_batch: ForwardBatch):
998
        self.attn_backend.init_forward_metadata(forward_batch)
999
        return self.model.forward(
1000
            forward_batch.input_ids, forward_batch.positions, forward_batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1001
1002
        )

1003
1004
1005
1006
1007
1008
    def forward_extend(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ):
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

1009
        if self.is_generation:
Rin Intachuen's avatar
Rin Intachuen committed
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
            if forward_batch.input_embeds is None:
                return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
                )
            else:
                return self.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    input_embeds=forward_batch.input_embeds.bfloat16(),
                )
1021
1022
1023
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
1024
1025
1026
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
1027
1028
                get_embedding=True,
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
1029

Ke Bao's avatar
Ke Bao committed
1030
1031
1032
1033
1034
    def forward_idle(self, forward_batch: ForwardBatch):
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

1035
1036
1037
    def forward(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ) -> LogitsProcessorOutput:
1038
1039
1040
1041
1042
        if (
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        ):
1043
1044
1045
            return self.cuda_graph_runner.replay(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )
1046

1047
1048
1049
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
1050
1051
1052
            return self.forward_extend(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )
Ke Bao's avatar
Ke Bao committed
1053
1054
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
Lianmin Zheng's avatar
Lianmin Zheng committed
1055
        else:
1056
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
1057

1058
1059
1060
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
1061
        # Apply logit bias
1062
1063
1064
1065
1066
1067
1068
1069
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
1070
1071
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

1095
1096
1097
        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
1098
            forward_batch.sampling_info,
1099
1100
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
1101
            forward_batch.token_ids_logprobs,
1102
        )
1103
1104
        return next_token_ids

Yineng Zhang's avatar
Yineng Zhang committed
1105
1106
1107
1108
1109
1110
1111
    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
1112
1113
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled
1114

1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Write the loaded model as a sharded checkpoint under `path`.

        `pattern` and `max_size` are forwarded unchanged to
        `ShardedStateLoader.save_model`.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader as _Loader

        summary = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(summary)
        _Loader.save_model(self.model, path, pattern, max_size)

1131
1132
1133
1134
1135
1136
1137
1138
1139

def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Materialize `tensor` for `tp_rank` on the current CUDA device.

    A `LocalSerializedTensor` is deserialized with the handle stored for
    `tp_rank`, after `monkey_patch_torch_reductions()` has been applied. Any
    other value is used as-is. The result is moved to
    `torch.cuda.current_device()`.
    """
    if isinstance(tensor, LocalSerializedTensor):
        monkey_patch_torch_reductions()
        tensor = tensor.get(tp_rank)
    return tensor.to(torch.cuda.current_device())


@dataclass
class LocalSerializedTensor:
    """A tensor carried as one serialized handle per rank.

    Each entry of `values` is produced by `MultiprocessingSerializer`, which
    serializes a pointer rather than the tensor data. Entry i belongs to the
    GPU of rank i.
    """

    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the tensor handle stored for `rank`."""
        deserialize = MultiprocessingSerializer.deserialize
        return deserialize(self.values[rank])