model_runner.py 47.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import collections
17
import datetime
18
import gc
19
import inspect
Shuo Yang's avatar
Shuo Yang committed
20
import json
21
import logging
22
import os
23
import time
24
25
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27

import torch
28
import torch.distributed as dist
29
30
31
32
33

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
34
35
36
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
37
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
38
)
39
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
40
41
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
42
    get_attention_tp_size,
43
44
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
45
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
46
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
47
48
49
50
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
51
from sglang.srt.layers.sampler import Sampler
52
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
53
from sglang.srt.lora.lora_manager import LoRAManager
54
from sglang.srt.managers.schedule_batch import global_server_args_dict
55
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
56
    DoubleSparseTokenToKVPool,
57
58
59
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
60
    TokenToKVPoolAllocator,
61
)
Lianmin Zheng's avatar
Lianmin Zheng committed
62
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
63
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
64
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
65
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
66
67
68
69
70
71
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
72
from sglang.srt.model_loader.weight_utils import default_weight_loader
73
from sglang.srt.patch_torch import monkey_patch_torch_reductions
74
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
75
from sglang.srt.server_args import ServerArgs
76
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
77
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
78
from sglang.srt.utils import (
79
    MultiprocessingSerializer,
80
    enable_show_time_cost,
81
    get_available_gpu_memory,
82
    get_bool_env_var,
83
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
84
    is_cuda,
85
    is_fa3_default_architecture,
86
    is_flashinfer_available,
HAI's avatar
HAI committed
87
    is_hip,
88
    is_hopper_with_cuda_12_3,
89
    is_no_spec_infer_or_topk_one,
90
    monkey_patch_p2p_access_check,
91
    monkey_patch_vllm_gguf_config,
92
    set_cpu_offload_max_bytes,
93
    set_cuda_arch,
94
)
95

Lianmin Zheng's avatar
Lianmin Zheng committed
96
# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Timeout (in seconds) used to detect straggler ranks during model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

104

Lianmin Zheng's avatar
Lianmin Zheng committed
105
class ModelRunner:
106
107
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
108
109
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Record the configuration, set up torch distributed, and load the model.

        The constructor stores its arguments on the instance, lets
        `model_specific_adjustment` tweak `server_args`, publishes a subset of
        the (adjusted) server arguments through `global_server_args_dict`,
        configures CPU offload, initializes torch distributed, and finally
        calls `initialize` to load the model and build the runtime state.
        """
        # Configuration objects
        self.model_config = model_config
        self.server_args = server_args
        self.mem_fraction_static = mem_fraction_static

        # Device and parallelism layout
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port

        # Flags derived from the arguments
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size

        # Pools may be handed in by the caller; they default to None.
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        self.use_mla_backend = model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        # Model-specific adjustment (may rewrite fields of `server_args`, so it
        # has to run before the values are published below).
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars: entries copied verbatim from `server_args`, in order.
        mirrored_fields = (
            "attention_backend",
            "debug_tensor_dump_inject",
            "debug_tensor_dump_output_folder",
            "deepep_mode",
            "device",
            "disable_chunked_prefix_cache",
            "disable_radix_cache",
            "enable_nan_detection",
            "enable_dp_attention",
            "enable_ep_moe",
            "enable_deepep_moe",
            "flashinfer_mla_disable_ragged",
            "moe_dense_tp_size",
            "n_share_experts_fusion",
            "triton_attention_reduce_in_fp32",
            "torchao_config",
            "sampling_backend",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        )
        shared_settings = {
            field: getattr(server_args, field) for field in mirrored_fields
        }
        shared_settings["use_mla_backend"] = self.use_mla_backend
        shared_settings["mm_attention_backend"] = server_args.mm_attention_backend
        global_server_args_dict.update(shared_settings)

        # CPU offload (the argument is given in GiB, the setter takes bytes)
        bytes_per_gb = 1024**3
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * bytes_per_gb))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # temporary cached values
        forward_params = inspect.signature(self.model.forward).parameters
        self.support_pp = "pp_proxy_tensors" in forward_params

198
199
    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build everything needed to run forward passes.

        Args:
            min_per_gpu_memory: value returned by `init_torch_distributed`;
                forwarded to `init_memory_pool`.
        """
        server_args = self.server_args
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Layer range handled by this runner; falls back to the full model
        # when the model does not expose `start_layer` / `end_layer`.
        total_layers = self.model_config.num_hidden_layers
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", total_layers)
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization
        # (in layered loading, torchao may already have been applied)
        if not getattr(self.model, "torchao_applied", False):
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init LoRA
        if server_args.lora_paths is not None:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )

        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

249
250
251
    def model_specific_adjustment(self):
        """Adjust `server_args` (and `mem_fraction_static`) for this model.

        Chooses or validates the attention backend, applies the double
        sparsity and multimodal overrides, and decides whether the chunked
        prefix cache stays enabled.

        Raises:
            ValueError: for an MLA model on CPU, for an attention backend that
                is not accepted for MLA, or when double sparsity is enabled
                without a heavy channel type.
        """
        server_args = self.server_args

        if server_args.attention_backend is None:
            # Auto select the attention backend.
            #
            # 1. Models with MHA architecture (e.g. Llama, QWen):
            #    1.1 Use FA3 when all three helper checks below pass.
            #    1.2 Otherwise use flashinfer if available, else triton.
            # 2. Models with MLA architecture:
            #    2.1 Use FA3 when `is_hopper_with_cuda_12_3()` is true.
            #    2.2 Otherwise use triton.
            if self.use_mla_backend:
                # MLA architecture
                chosen = "fa3" if is_hopper_with_cuda_12_3() else "triton"
            elif (
                is_hopper_with_cuda_12_3()
                and is_no_spec_infer_or_topk_one(server_args)
                and is_fa3_default_architecture(self.model_config.hf_config)
            ):
                # MHA architecture, FA3 eligible
                chosen = "fa3"
            elif is_flashinfer_available():
                chosen = "flashinfer"
            else:
                chosen = "triton"
            server_args.attention_backend = chosen

            if self.should_log:
                logger.info(
                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
                )
        elif self.use_mla_backend:
            # The user picked a backend explicitly; validate it for MLA.
            if server_args.device == "cpu":
                raise ValueError("MLA optimization not supported on CPU.")

            mla_backends = (
                "flashinfer",
                "fa3",
                "triton",
                "flashmla",
                "cutlass_mla",
            )
            if server_args.attention_backend not in mla_backends:
                raise ValueError(
                    f"Invalid attention backend for MLA: {server_args.attention_backend}"
                )
            if self.should_log:
                logger.info(
                    f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                )

        uses_fa3 = server_args.attention_backend == "fa3"
        if uses_fa3 and server_args.kv_cache_dtype == "fp8_e5m2":
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            if self.should_log:
                logger.info(
                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
                )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            heavy_channel_type = server_args.ds_heavy_channel_type
            if heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            if self.should_log:
                logger.info(
                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                    f"because this is a multimodal model."
                )
                logger.info(
                    "Automatically turn off --chunked-prefill-size for multimodal model."
                )
            server_args.chunked_prefill_size = -1

        # The chunked prefix cache is only kept for MLA models with page size 1.
        if self.use_mla_backend:
            if self.page_size > 1:
                if self.should_log:
                    logger.info("Disable chunked prefix cache when page size > 1.")
                server_args.disable_chunked_prefix_cache = True
        else:
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache and self.should_log:
            logger.info("Chunked prefix cache is turned on.")
351

352
    def init_torch_distributed(self):
        """Bind this process to its device and set up the distributed groups.

        On the target (non-draft) worker this initializes the distributed
        environment, model parallelism and DP attention. In all cases it
        records the TP groups on the instance and checks that the available
        memory is balanced across TP ranks.

        Returns:
            The value of `get_available_gpu_memory(..., distributed=tp_size > 1)`
            measured after the distributed setup.

        Raises:
            ValueError: if the device has no known communication backend (on
                the target worker), or if the available memory is unbalanced
                across TP ranks and `SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`
                is not set.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Communication backend per device type.
        backend_by_device = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
            "npu": "hccl",
        }
        backend = backend_by_device.get(self.device)
        # The backend is only consumed on the target worker. Fail early with a
        # clear message instead of hitting an unbound local further down.
        if backend is None and not self.is_draft_worker:
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device!r}. "
                f"Supported devices: {sorted(backend_by_device)}."
            )

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            imbalance_message = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(imbalance_message)
            else:
                raise ValueError(imbalance_message)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
431

Lianmin Zheng's avatar
Lianmin Zheng committed
432
    def load_model(self):
        """Load the model weights and record model-derived attributes.

        Sets `self.load_config`, `self.model`, `self.sliding_window_size` and
        `self.dtype`, then waits on a monitored barrier so that a rank that
        failed or stalled during loading is detected.

        Raises:
            RuntimeError: if the CUDA device is older than sm75, or if FP8 KV
                cache scaling factors were provided but the model cannot load
                them.
            ValueError: if the monitored barrier fails, i.e. some other rank
                did not finish loading within
                `UNBALANCED_MODEL_LOADING_TIMEOUT_S` seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                if self.should_log:
                    logger.info(
                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                    )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare the full (major, minor) pair: checking only the minor
                # digit would let e.g. sm_37 through.
                if (major, minor) < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    if self.should_log:
                        logger.info(
                            "Loaded KV cache scaling factors from %s",
                            self.server_args.quantization_param_path,
                        )
                else:
                    # Exceptions do not apply %-formatting to their arguments,
                    # so the message has to be built explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading "
                        "scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

529
530
531
532
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: path of the checkpoint to load the new weights from.
            load_format: load format used to build the `LoadConfig`.

        Returns:
            A `(success, message)` pair. On failure the previously configured
            model path is restored; if loading the new weights failed midway,
            the weights are reloaded from that previous path before returning.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        # Remember where the current weights came from so a failed update can
        # roll back to them instead of re-reading the failing path.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path

        def get_weight_iter(config):
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                # Nothing was loaded yet; just undo the path change.
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Point the config back at the original checkpoint before
                # reloading; the same loader is reused for the rollback.
                self.model_config.model_path = original_model_path
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
589

590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Create the Torch process group used for online parameter updates.

        The resulting `_model_update_group` serves the RLHF workflow: rank 0
        is the actor model in the training engine, the remaining ranks are the
        inference engine used for rollout. The training engine updates the
        weights online and broadcasts them to the inference engine through
        this group.

        Returns:
            A `(success, message)` pair; failures during group creation are
            logged and reported through the pair instead of being raised.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        # This worker's rank inside the update group.
        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            self._model_update_group = update_group
            # NOTE(review): `device_ids` receives the group rank, not
            # `self.gpu_id` - confirm the two are expected to coincide.
            dist.barrier(group=update_group, device_ids=[rank])
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message
        else:
            return True, "Succeeded to initialize custom process group."

    def update_weights_from_distributed(self, name, dtype, shape):
        """Receive one parameter through `_model_update_group` and load it online.

        An empty tensor of the requested shape/dtype is allocated on `self.device`,
        filled by a broadcast from rank 0 of `_model_update_group`, and then handed
        to `self.model.load_weights` under the given name.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter; either a `torch.dtype` or the
                name of an attribute of `torch` (looked up with `getattr`).
            shape: the shape of the parameter to be updated.

        Returns:
            A `(success, message)` tuple. Failures inside the allocate/broadcast/load
            sequence are logged and reported as `(False, message)`.
        """
        if isinstance(dtype, torch.dtype):
            target_dtype = dtype
        else:
            # String form: resolve e.g. "float16" to torch.float16.
            target_dtype = getattr(torch, dtype)

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            recv_buffer = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(
                recv_buffer, src=0, group=self._model_update_group
            )
            self.model.load_weights([(name, recv_buffer)])
            return True, f"Succeeded to update parameter {name} online."
        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load in-memory tensors into the model.

        Every tensor is first passed through `_unwrap_tensor` with this runner's
        `tp_rank`, then the whole list is loaded according to `load_format`.

        Args:
            named_tensors: `(name, tensor)` pairs; a tensor may also be a
                `LocalSerializedTensor`.
            load_format: `None` to use `self.model.load_weights`, or `"direct"` to
                use `_model_load_weights_direct`.

        Returns:
            The tuple `(True, "Success")`.

        Raises:
            NotImplementedError: If `load_format` is neither `None` nor `"direct"`.
        """
        unwrapped = []
        for param_name, payload in named_tensors:
            unwrapped.append((param_name, _unwrap_tensor(payload, tp_rank=self.tp_rank)))

        if load_format is None:
            self.model.load_weights(unwrapped)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
685

686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Look up a parameter's weights by name, like `get_parameter` in Hugging Face.

        The lookup is delegated to `self.model.get_weights_by_name` together with
        `truncate_size` and this runner's `tp_size`. Any exception is logged and
        turned into a `None` result.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None
        return weights

703
704
705
706
707
708
709
710
    def init_lora_manager(self):
        """Build the `LoRAManager` for the loaded model and keep it on `self.lora_manager`.

        All constructor arguments are taken from this runner (`model`, `model_config`,
        `load_config`, `dtype`, `tp_size`, `tp_rank`) and from `server_args`
        (`lora_paths`, `max_loras_per_batch`, `lora_backend`).
        """
        args = self.server_args
        manager = LoRAManager(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        self.lora_manager = manager
        logger.info("LoRA manager ready.")

717
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many KV-cache tokens fit into the memory left for the pool.

        The per-token cost ("cell size", in bytes) depends on the attention layout:
        with the MLA backend it is `(kv_lora_rank + qk_rope_head_dim)` per layer,
        otherwise it is `num_kv_heads * head_dim * 2` per effective layer; both are
        multiplied by the element size of `self.kv_cache_dtype`.

        The memory budget is the currently available GPU memory minus
        `total_gpu_memory * (1 - mem_fraction_static)`. It is scaled by `1 << 30`
        before dividing by the cell size, i.e. memory values are treated as GiB.

        Args:
            total_gpu_memory: Total GPU memory used to compute the reserved share.

        Returns:
            The profiled maximum number of tokens as an int (may be <= 0 when the
            budget is exhausted).
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        if not self.use_mla_backend:
            bytes_per_token = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.num_effective_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            if self.is_draft_worker:
                num_layers = self.model_config.hf_config.num_nextn_predict_layers
            else:
                num_layers = self.model_config.num_hidden_layers
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            latent_width = (
                self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim
            )
            bytes_per_token = (
                latent_width
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // bytes_per_token)

748
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and allocate the request/token memory pools of this runner.

        Steps, in order:
          1. Resolve `self.kv_cache_dtype` from `server_args.kv_cache_dtype`.
          2. Profile the token capacity with `profile_max_num_token`.
          3. Derive `max_num_reqs` when it is not given, apply the CI override and
             the speculative-decoding adjustments, then cap by `max_total_tokens`.
          4. Round the capacity down to a multiple of `server_args.page_size`.
          5. Create `req_to_token_pool`, `token_to_kv_pool` and
             `token_to_kv_pool_allocator`. The first and the last are reused when
             already set, which is only permitted for a draft worker.

        Args:
            total_gpu_memory: Total GPU memory, forwarded to `profile_max_num_token`.
            max_num_reqs: Maximum number of requests; derived from the profiled
                capacity and the context length (clamped to [2048, 4096]) when None.
            max_total_tokens: Optional upper bound on the token capacity.

        Raises:
            ValueError: If `server_args.kv_cache_dtype` is unknown, or is
                `"fp8_e4m3"` on a platform where `is_cuda()` is false.
            RuntimeError: If the final token capacity is not positive.
        """
        requested_dtype = self.server_args.kv_cache_dtype
        if requested_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif requested_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif requested_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
            else:
                # Fail here with a clear message instead of leaving
                # `self.kv_cache_dtype` unassigned and crashing later.
                raise ValueError(
                    f"Unsupported kv_cache_dtype: {requested_dtype}. "
                    f"fp8_e4m3 is only supported on CUDA devices."
                )
        else:
            raise ValueError(f"Unsupported kv_cache_dtype: {requested_dtype}.")

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            # Scale with capacity per context window, clamped to [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (not the root `logging` module) so the
                # message follows the same configuration as the rest of the file.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            # An existing allocator is only expected on a draft worker.
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
907

Lianmin Zheng's avatar
Lianmin Zheng committed
908
909
910
911
912
913
914
915
916
    def init_cublas(self):
        """Warm up cuBLAS with a tiny matmul; otherwise some errors are raised later.

        Multiplies two 16x16 float16 all-ones matrices on the "cuda" device and
        returns the product.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones_like(lhs)
        return torch.matmul(lhs, rhs)

917
918
    def init_attention_backend(self):
        """Init attention kernel backend.

        Picks the backend class named by `server_args.attention_backend`, imports
        it lazily, and stores the constructed instance on `self.attn_backend`.
        The "triton" and "fa3" choices assert their own preconditions first.

        Raises:
            ValueError: If the configured backend name is not recognized.
        """
        backend_name = self.server_args.attention_backend

        if backend_name == "flashinfer":
            if self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
        elif backend_name == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if not self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
            else:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
        elif backend_name == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
        elif backend_name == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        elif backend_name == "fa3":
            # Major capability 9 is always accepted; 8 only without the MLA backend.
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
        elif backend_name == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
986

Shuo Yang's avatar
Shuo Yang committed
987
988
989
990
991
992
993
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer channel selections used by double sparsity.

        Reads the JSON file at `server_args.ds_channel_config_path` and, for every
        layer index in `[start_layer, end_layer)`, looks up the entry keyed
        `model.layers.<i>.self_attn.<selected_channel>_proj`. Each entry is turned
        into a tensor, truncated to the first `ds_heavy_channel_num` columns, made
        contiguous, moved to CUDA, and appended to `self.sorted_channels`.

        Args:
            selected_channel: Projection name fragment used to build the JSON keys.
        """
        proj_suffix = "".join((".", selected_channel, "_proj"))
        self.sorted_channels = []

        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        for layer_idx in range(self.start_layer, self.end_layer):
            config_key = f"model.layers.{layer_idx}.self_attn{proj_suffix}"
            all_channels = torch.tensor(channel_config[config_key])
            heavy_channels = all_channels[:, : self.server_args.ds_heavy_channel_num]
            self.sorted_channels.append(heavy_channels.contiguous().cuda())

1004
    def init_cuda_graphs(self):
        """Capture CUDA graphs.

        `self.cuda_graph_runner` is reset to None first and stays None when the
        model is not a generation model or when `server_args.disable_cuda_graph`
        is set. Otherwise a `CudaGraphRunner` is built, and the elapsed time and
        the change in available GPU memory are logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, CUDA graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        start_time = time.time()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture CUDA graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )

        self.cuda_graph_runner = CudaGraphRunner(self)

        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture CUDA graph end. Time elapsed: {time.time() - start_time:.2f} s. "
            f"mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )
1026

1027
    def apply_torch_tp(self):
        """Apply torch tensor parallelism to `self.model`.

        Builds a one-dimensional device mesh of size `tp_size` on `self.device`
        and passes it, together with the model, to `tensor_parallel`.
        """
        if self.should_log:
            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")

        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

1035
1036
1037
    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run one decode forward pass.

        Initializes the attention backend metadata for the batch, then calls
        `self.model.forward` with the batch's input ids and positions.
        `pp_proxy_tensors` is forwarded only when `self.support_pp` is set.
        """
        self.attn_backend.init_forward_metadata(forward_batch)

        # FIXME: add pp_proxy_tensors arg to all models
        extra_args = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **extra_args,
        )

1047
    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        """Run one extend forward pass.

        Unless `skip_attn_backend_init` is set, the attention backend metadata is
        initialized for the batch first. Optional keyword arguments passed to
        `self.model.forward`:
          - `pp_proxy_tensors` when `self.support_pp` is set,
          - `input_embeds` (cast with `.bfloat16()`) when the batch carries them,
          - `get_embedding=True` when this is not a generation model.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        extra_args = {}
        if self.support_pp:
            extra_args["pp_proxy_tensors"] = pp_proxy_tensors

        embeds = forward_batch.input_embeds
        if embeds is not None:
            extra_args["input_embeds"] = embeds.bfloat16()

        if not self.is_generation:
            extra_args["get_embedding"] = True

        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **extra_args,
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1069

1070
1071
1072
1073
1074
1075
    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run a forward pass for an idle batch.

        Calls `self.model.forward` with the batch's input ids and positions; no
        attention metadata is initialized here. `pp_proxy_tensors` is forwarded
        only when `self.support_pp` is set.
        """
        extra_args = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **extra_args,
        )

1083
    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Dispatch a batch to CUDA graph replay or to the matching forward method.

        CUDA graph replay is used when the batch's forward mode reports
        `is_cuda_graph()`, a graph runner exists, and the runner says it can run
        the batch. Otherwise the batch goes to `forward_decode`, `forward_extend`
        or `forward_idle` according to its forward mode.

        Returns:
            A `(output, can_run_cuda_graph)` tuple, where the bool tells whether
            the CUDA graph path was taken.

        Raises:
            ValueError: If the forward mode matches none of the handled cases.
        """
        mode = forward_batch.forward_mode
        can_run_cuda_graph = bool(
            mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )

        if can_run_cuda_graph:
            output = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return output, can_run_cuda_graph

        if mode.is_decode():
            output = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif mode.is_extend():
            output = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif mode.is_idle():
            output = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return output, can_run_cuda_graph

1115
1116
1117
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
1118
        # Apply logit bias
1119
1120
1121
1122
1123
1124
1125
1126
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
1127
1128
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward. A tuple is
                treated as several output streams that are sampled one by one.
            forward_batch: The forward batch that generates logits_output

        Returns:
            The next token ids produced by `self.sampler`. For a tuple input, the
            per-stream results are stacked along the last dimension.
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            per_stream_ids = [
                self.sample(stream_logits, forward_batch)
                for stream_logits in logits_output
            ]
            return torch.stack(per_stream_ids, dim=-1)

        sampling_info = forward_batch.sampling_info
        self._preprocess_logits(logits_output, sampling_info)

        # Sample the next tokens
        return self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

Yineng Zhang's avatar
Yineng Zhang committed
1162
1163
1164
1165
1166
1167
1168
    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.

        True when the HF config's `rope_scaling` (defaulting to an empty dict when
        absent) is not None and contains the key "mrope_section".
        mrope requires keep "rope_deltas" between prompt and decoding phases.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        return rope_scaling is not None and "mrope_section" in rope_scaling
1171

1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
    def save_remote_model(self, url: str):
        """Save the model to `url` through `RemoteModelLoader.save_model`.

        The loader receives the model, the configured `model_path`, and the url.
        """
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(
            self.model,
            self.model_config.model_path,
            url,
        )

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the model to `path` through `ShardedStateLoader.save_model`.

        `pattern` and `max_size` are passed to the loader unchanged.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        log_line = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(log_line)
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

1188
1189
1190
1191
1192
1193
1194
1195
1196

def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return `tensor` on the current CUDA device.

    A `LocalSerializedTensor` is first resolved to the entry for `tp_rank`
    (after `monkey_patch_torch_reductions()` has been applied); anything else
    is moved to the device directly.
    """
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor.to(torch.cuda.current_device())

    monkey_patch_torch_reductions()
    local_tensor = tensor.get(tp_rank)
    return local_tensor.to(torch.cuda.current_device())
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210


@dataclass
class LocalSerializedTensor:
    """A torch.Tensor serialized once per rank by MultiprocessingSerializer.

    The serializer only serializes a pointer and not the data.
    The i-th element in `values` corresponds to the i-th rank's GPU.
    """

    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the entry belonging to `rank`."""
        serialized = self.values[rank]
        return MultiprocessingSerializer.deserialize(serialized)