model_runner.py 47.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import collections
17
import datetime
18
import gc
19
import inspect
Shuo Yang's avatar
Shuo Yang committed
20
import json
21
import logging
22
import os
23
import time
24
25
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27

import torch
28
import torch.distributed as dist
29
30
31
32
33

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
34
    get_tp_group,
35
    get_world_group,
zhyncs's avatar
zhyncs committed
36
37
    init_distributed_environment,
    initialize_model_parallel,
38
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
39
)
40
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
41
42
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
43
    get_attention_tp_size,
44
45
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
46
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
47
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
48
49
50
51
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
52
from sglang.srt.layers.sampler import Sampler
53
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
54
from sglang.srt.lora.lora_manager import LoRAManager
55
from sglang.srt.managers.schedule_batch import global_server_args_dict
56
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
57
    DoubleSparseTokenToKVPool,
58
59
60
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
61
    TokenToKVPoolAllocator,
62
)
Lianmin Zheng's avatar
Lianmin Zheng committed
63
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
64
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
65
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
66
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68
69
70
71
72
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
73
from sglang.srt.model_loader.weight_utils import default_weight_loader
74
from sglang.srt.patch_torch import monkey_patch_torch_reductions
75
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
76
from sglang.srt.server_args import ServerArgs
77
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
78
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
79
from sglang.srt.utils import (
80
    MultiprocessingSerializer,
81
    enable_show_time_cost,
82
    get_available_gpu_memory,
83
    get_bool_env_var,
84
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
85
    is_cuda,
86
    is_fa3_default_architecture,
87
    is_flashinfer_available,
HAI's avatar
HAI committed
88
    is_hip,
89
    is_hopper_with_cuda_12_3,
90
    is_no_spec_infer_or_topk_one,
91
    monkey_patch_p2p_access_check,
92
    monkey_patch_vllm_gguf_config,
93
    set_cpu_offload_max_bytes,
94
    set_cuda_arch,
95
)
96

Lianmin Zheng's avatar
Lianmin Zheng committed
97
# Optional override for a small KV cache pool size, used by tests in CI.
# None unless the SGLANG_CI_SMALL_KV_SIZE environment variable is set.
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Seconds the post-load barrier waits before concluding that some ranks are
# stragglers in model loading (used by ModelRunner.load_model).
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

105

106
107
108
109
110
111
112
113
114
115
116
117
118
class RankZeroFilter(logging.Filter):
    """Logging filter that gates INFO records on being rank 0.

    INFO records pass only when the filter was built with a truthy
    ``is_rank_zero``; records of every other level always pass.
    """

    def __init__(self, is_rank_zero):
        super().__init__()
        # Decided by the caller (e.g. ``tp_rank == 0``).
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        # Only the INFO level is rank-gated.
        is_info_record = record.levelno == logging.INFO
        return self.is_rank_zero if is_info_record else True


Lianmin Zheng's avatar
Lianmin Zheng committed
119
class ModelRunner:
120
121
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
122
123
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Create a runner for one model replica on one device.

        Construction proceeds in this order:
          1. record the arguments;
          2. let model_specific_adjustment() tune ``server_args``;
          3. publish selected settings to ``global_server_args_dict``;
          4. initialize torch.distributed;
          5. call initialize() to load the model and build the memory pools and
             attention backends.

        Args:
            model_config: configuration of the model to run.
            mem_fraction_static: static memory fraction; model_specific_adjustment()
                may lower it (e.g. for multimodal models).
            gpu_id: local device index this process binds to.
            tp_rank: tensor-parallel rank.
            tp_size: tensor-parallel world size.
            pp_rank: pipeline-parallel rank.
            pp_size: pipeline-parallel world size.
            nccl_port: port used for ``tcp://127.0.0.1`` distributed init when
                ``server_args.dist_init_addr`` is not set.
            server_args: server configuration; mutated in place by the adjustment step.
            is_draft_worker: True for a speculative-decoding draft model; draft
                workers do not create the distributed environment themselves.
            req_to_token_pool: optional pre-built request-to-token pool.
            token_to_kv_pool_allocator: optional pre-built KV pool allocator.
        """
        # ---- Record constructor arguments ----
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # INFO logs only from tp_rank 0; install the filter at most once per process.
        filter_installed = any(
            isinstance(existing, RankZeroFilter) for existing in logger.filters
        )
        if not filter_installed:
            logger.addFilter(RankZeroFilter(tp_rank == 0))

        self.tp_rank, self.tp_size = tp_rank, tp_size
        self.pp_rank, self.pp_size = pp_rank, pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        # Caller-supplied pools, or None.
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        # Model-specific adjustment; may rewrite fields of server_args.
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Publish the (possibly adjusted) settings through the global dict.
        shared_settings = {
            "attention_backend": server_args.attention_backend,
            "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
            "deepep_mode": server_args.deepep_mode,
            "device": server_args.device,
            "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
            "disable_radix_cache": server_args.disable_radix_cache,
            "enable_nan_detection": server_args.enable_nan_detection,
            "enable_dp_attention": server_args.enable_dp_attention,
            "enable_ep_moe": server_args.enable_ep_moe,
            "enable_deepep_moe": server_args.enable_deepep_moe,
            "deepep_config": server_args.deepep_config,
            "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
            "moe_dense_tp_size": server_args.moe_dense_tp_size,
            "n_share_experts_fusion": server_args.n_share_experts_fusion,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "torchao_config": server_args.torchao_config,
            "sampling_backend": server_args.sampling_backend,
            "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
            "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
            "use_mla_backend": self.use_mla_backend,
            "mm_attention_backend": server_args.mm_attention_backend,
        }
        global_server_args_dict.update(shared_settings)

        # CPU offload budget, converted from GiB to bytes.
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Memory is measured here, before the model is loaded.
        min_per_gpu_memory = self.init_torch_distributed()

        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # For a draft model the tp_group can differ from the target model's.
        self.initialize(min_per_gpu_memory)

        # Cache whether forward() accepts pipeline-parallel proxy tensors.
        forward_params = inspect.signature(self.model.forward).parameters
        self.support_pp = "pp_proxy_tensors" in forward_params

216
217
    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build every component that depends on it.

        Order of work:
          1. memory-saver adapter;
          2. sampler and model weights;
          3. the layer range owned by this runner;
          4. torchao quantization;
          5. optional torch TP and LoRA;
          6. memory pools;
          7. attention backend (plus cuBLAS and CUDA graphs on CUDA);
          8. EAGLE3 hidden-state capture on the target worker.

        Args:
            min_per_gpu_memory: available device memory value forwarded to
                init_memory_pool().
        """
        args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=args.enable_memory_saver
        )

        # Load the model.
        self.sampler = Sampler()
        self.load_model()

        # Layer range handled here; defaults to all layers when the model
        # does not declare start_layer/end_layer.
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # torchao may already have been applied during layered loading.
        already_quantized = getattr(self.model, "torchao_applied", False)
        if not already_quantized:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Torch-native TP only when the model opts in and there is more than one rank.
        model_supports_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and model_supports_tp:
            self.apply_torch_tp()

        if args.lora_paths is not None:
            self.init_lora_manager()

        # Memory pools first, then attention backends (which run after them).
        self.init_memory_pool(
            min_per_gpu_memory,
            args.max_running_requests,
            args.max_total_tokens,
        )

        use_cuda = self.device == "cuda"
        if use_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if use_cuda:
            self.init_cuda_graphs()

        # Auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

267
268
269
    def model_specific_adjustment(self):
        """Tune ``self.server_args`` (and ``mem_fraction_static``) for this model.

        In place, this method:
          * chooses an attention backend when none was given, or validates an
            explicit backend for MLA models;
          * replaces fa3 with triton when the KV cache dtype is fp8_e5m2;
          * forces triton without CUDA graphs for double sparsity;
          * lowers mem_fraction_static and disables chunked prefill for
            multimodal models;
          * decides whether the chunked prefix cache stays enabled.

        Raises:
            ValueError: unsupported MLA backend, MLA on CPU, or double
                sparsity without ``ds_heavy_channel_type``.
        """
        args = self.server_args

        if args.attention_backend is None:
            # Auto-select the attention backend:
            #   MHA models (e.g. Llama, Qwen): "fa3" when is_hopper_with_cuda_12_3()
            #     holds, speculative decoding is off or uses topk == 1, and the
            #     architecture is fa3-by-default; otherwise "flashinfer" when it
            #     is installed, else "triton".
            #   MLA models: "fa3" when is_hopper_with_cuda_12_3() holds, else "triton".
            if self.use_mla_backend:
                chosen = "fa3" if is_hopper_with_cuda_12_3() else "triton"
            else:
                prefer_fa3 = (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                )
                if prefer_fa3:
                    chosen = "fa3"
                else:
                    chosen = "flashinfer" if is_flashinfer_available() else "triton"
            args.attention_backend = chosen
            logger.info(
                f"Attention backend not set. Use {args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            # An explicit backend was requested for an MLA model: validate it.
            if args.device == "cpu":
                raise ValueError("MLA optimization not supported on CPU.")
            mla_backends = ("flashinfer", "fa3", "triton", "flashmla", "cutlass_mla")
            if args.attention_backend not in mla_backends:
                raise ValueError(
                    f"Invalid attention backend for MLA: {args.attention_backend}"
                )
            logger.info(
                f"MLA optimization is turned on. Use {args.attention_backend} backend."
            )

        # FA3 cannot serve an fp8_e5m2 KV cache.
        if args.attention_backend == "fa3" and args.kv_cache_dtype == "fp8_e5m2":
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            args.attention_backend = "triton"

        if args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            args.attention_backend = "triton"
            args.disable_cuda_graph = True
            if args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
            )
            args.chunked_prefill_size = -1
            logger.info(
                "Automatically turn off --chunked-prefill-size for multimodal model."
            )

        # Chunked prefix cache stays on only for MLA models with page_size == 1.
        if not self.use_mla_backend:
            args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            args.disable_chunked_prefix_cache = True

        if not args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")
362

363
    def init_torch_distributed(self):
        """Bind this process to its device and initialize torch.distributed.

        Non-draft workers create the distributed environment, the model-parallel
        groups and the DP-attention groups. Draft workers skip that step and
        only look up the already-created groups.

        Returns:
            Minimum available device memory (GB) across the world group, measured
            after distributed initialization.

        Raises:
            ValueError: a non-draft worker runs on a device with no known
                collective backend, or TP ranks report unbalanced free memory
                (unless SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK is set).
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Collective communication backend per device type. Unknown devices map
        # to None; that only matters for non-draft workers (checked below).
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
            "npu": "hccl",
        }.get(self.device)

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Fail with a clear message instead of an UnboundLocalError on
            # `backend` for devices not listed above.
            if backend is None:
                raise ValueError(
                    f"Unsupported device for torch distributed: {self.device!r}"
                )
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                # Global rank: pipeline stage major, TP rank minor.
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism. Warn or fail when the smallest
        # free memory across ranks is below 90% of this rank's free memory.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            imbalance_msg = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(imbalance_msg)
            else:
                raise ValueError(imbalance_msg)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
446

Lianmin Zheng's avatar
Lianmin Zheng committed
447
    def load_model(self):
        """Load the model weights onto the device and check that all TP ranks finished.

        Side effects:
          * may switch dtype to float16 on CUDA GPUs below sm80;
          * sets ``self.load_config``, ``self.model``, ``self.sliding_window_size``
            and ``self.dtype``;
          * loads FP8 KV-cache scaling factors when they are configured.

        Raises:
            RuntimeError: the CUDA GPU is below sm75, or FP8 KV-cache scaling
                factors were provided for a model that cannot load them.
            ValueError: another TP rank did not reach the post-load barrier
                within UNBALANCED_MODEL_LOADING_TIMEOUT_S seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)

        if self.device == "cuda":
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare the full (major, minor) pair: checking only the minor
                # number wrongly accepted e.g. sm35/sm37 and rejected nothing
                # with minor >= 5.
                if (major, minor) < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region():
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Undo the vllm patches even if loading raised.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    # Exceptions do not %-format their arguments, so build the
                    # message up front.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

542
543
544
545
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Only loaders of type DefaultModelLoader are supported. On any failure,
        ``model_config.model_path`` is restored to its previous value. If the
        failure happens while weights are being written into the model, the
        previous checkpoint is reloaded as a rollback.

        Args:
            model_path: path of the checkpoint to load.
            load_format: load format for the new checkpoint.

        Returns:
            ``(True, message)`` on success, ``(False, message)`` on failure.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember the currently loaded checkpoint so that a failed update can
        # restore it; the loader reads model_path from model_config.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(weight_loader, config):
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(loader, self.model_config)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back to the ORIGINAL checkpoint. Restore the path first;
                # otherwise the rollback would re-read the checkpoint that just
                # failed. Prefer the loader for the original load config and
                # fall back to the new one if it is not a DefaultModelLoader.
                self.model_config.model_path = original_model_path
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    rollback_loader = loader
                weights_iter = get_weight_iter(rollback_loader, self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
596

597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Create the process group ``_model_update_group`` for online weight updates.

        In the RLHF workflow, rank 0 of this group is the actor model in the
        training engine and the other ranks are inference engines used for
        rollout. The training engine updates the weights online and broadcasts
        them to the inference engines through this group. This runner joins as
        rank ``rank_offset + self.tp_rank``.

        Returns:
            ``(True, message)`` when the group was created, or ``(False, message)``
            when creating it raised (the error is also logged).
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        member_rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={member_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        rendezvous_url = f"tcp://{master_address}:{master_port}"
        try:
            update_group = init_custom_process_group(
                backend=backend,
                init_method=rendezvous_url,
                world_size=world_size,
                rank=member_rank,
                group_name=group_name,
            )
        except Exception as exc:
            failure = f"Failed to initialize custom process group: {exc}."
            logger.error(failure)
            return False, failure
        else:
            self._model_update_group = update_group
            return True, "Succeeded to initialize custom process group."

    def update_weights_from_distributed(self, name, dtype, shape):
        """Receive one parameter over ``_model_update_group`` and load it.

        Args:
            name: the name of the parameter to be updated.
            dtype: a ``torch.dtype``, or the name of one as a string
                (looked up with ``getattr(torch, dtype)``).
            shape: the shape of the parameter to be updated.

        Returns:
            ``(True, message)`` on success. ``(False, message)`` if receiving or
            loading failed; the model weights may then be partially updated and
            should be discarded.

        Raises:
            AssertionError: init_weights_update_group() has not created the group.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        # The attribute only exists after init_weights_update_group() succeeded;
        # a plain attribute access would raise AttributeError instead of this
        # assertion.
        assert (
            getattr(self, "_model_update_group", None) is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load in-memory tensors into the model weights.

        Every entry is first passed through `_unwrap_tensor` for this runner's
        TP rank (which deserializes `LocalSerializedTensor` values), and the
        resulting pairs are then loaded.

        Args:
            named_tensors: `(parameter name, tensor)` pairs.
            load_format: `None` to use the model's own `load_weights`, or
                `"direct"` to feed each tensor to `default_weight_loader` for
                the parameter of the same name.

        Returns:
            `(True, "Success")` after loading.

        Raises:
            NotImplementedError: for any other `load_format`.
        """
        unwrapped = []
        for param_name, raw_tensor in named_tensors:
            unwrapped.append((param_name, _unwrap_tensor(raw_tensor, tp_rank=self.tp_rank)))

        if load_format is None:
            self.model.load_weights(unwrapped)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Look up the weights of the parameter called `name`.

        Similar to `get_parameter` in Hugging Face. Delegates to the model's
        own `get_weights_by_name`, forwarding `truncate_size` and this runner's
        `tp_size`.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Returns:
            The model's result, or `None` when the lookup raised; the error is
            logged instead of propagated.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            found = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None
        return found

    def init_lora_manager(self):
        """Build `self.lora_manager` around the loaded base model.

        The adapter paths, batch limit and backend come from `server_args`.
        The config, dtype and TP layout come from this runner.
        """
        server_args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            lora_paths=server_args.lora_paths,
            max_loras_per_batch=server_args.max_loras_per_batch,
            lora_backend=server_args.lora_backend,
            load_config=self.load_config,
            dtype=self.dtype,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit in the free GPU memory.

        Memory amounts are in GB, the unit that `get_available_gpu_memory`
        values are logged with elsewhere. They are converted to bytes via
        `1 << 30` before dividing by the per-token cache size.

        Args:
            total_gpu_memory: total GPU memory. A `(1 - mem_fraction_static)`
                share of it is held back from the KV cache.

        Returns:
            The number of tokens whose KV entries fit in the remaining memory.
        """
        world_group = get_world_group()
        free_gb = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=world_group.world_size > 1,
            cpu_group=world_group.cpu_group,
        )

        if self.use_mla_backend:
            if self.is_draft_worker:
                num_layers = self.model_config.hf_config.num_nextn_predict_layers
            else:
                num_layers = self.model_config.num_hidden_layers
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            # One (kv_lora_rank + qk_rope_head_dim)-wide entry per layer.
            bytes_per_token = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            # The factor 2 presumably accounts for separate K and V buffers.
            bytes_per_token = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.num_effective_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )

        reserved_gb = total_gpu_memory * (1 - self.mem_fraction_static)
        usable_gb = free_gb - reserved_gb
        return int(usable_gb * (1 << 30) // bytes_per_token)

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and allocate the request and KV-cache memory pools.

        Steps:
        1. Resolve the KV-cache dtype from `server_args.kv_cache_dtype`.
        2. Profile the token budget and apply any overrides (CI, speculative
           decoding, `max_total_tokens`).
        3. Round the budget down to whole pages.
        4. Create `req_to_token_pool`, `token_to_kv_pool` and
           `token_to_kv_pool_allocator`.

        `req_to_token_pool` and the allocator are only reused when already set,
        which is asserted to happen only for a draft worker.

        Args:
            total_gpu_memory: total GPU memory, passed to `profile_max_num_token`.
            max_num_reqs: maximum number of concurrent requests. When `None`,
                it is derived from the token budget and the context length.
            max_total_tokens: optional user cap on the KV-cache token count.

        Raises:
            ValueError: if `server_args.kv_cache_dtype` is not supported.
            RuntimeError: if the final token budget is not positive.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            # Mirror the fp8_e5m2 branch so a dtype is chosen on every
            # platform; selecting it only under CUDA left `self.kv_cache_dtype`
            # unset elsewhere and failed later in `profile_max_num_token`.
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            # Scale with the number of full-context requests that fit,
            # clamped to the range [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (as the rest of this class does), not
                # the root `logging` logger.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """Warm up cuBLAS with a tiny fp16 matmul on the "cuda" device.

        Without this warm-up, cuBLAS can raise errors later on. Returns the
        16x16 product of two all-ones matrices.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        return lhs @ rhs

    def init_attention_backend(self):
        """Init attention kernel backend.

        Sets `self.attn_backend` according to `server_args.attention_backend`.
        Each backend's module is imported inside its own branch, so only the
        selected backend is loaded.

        Raises:
            AssertionError: if the selected backend does not support this
                model or device (see the per-backend checks below).
            ValueError: if the backend name is not recognized.
        """
        backend = self.server_args.attention_backend
        if backend == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
        elif backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
        elif backend == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
        elif backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        elif backend == "fa3":
            # Query the device once. SM8x is accepted only for non-MLA models,
            # while SM90 is accepted for both; the message says so explicitly.
            major, _ = torch.cuda.get_device_capability()
            assert (major == 8 and not self.use_mla_backend) or major == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90 "
                "(MLA models require SM90). "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
        elif backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend}")

    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer channel indices for double sparsity into `self.sorted_channels`.

        Reads the JSON file at `server_args.ds_channel_config_path`. For each
        layer `i` in `[start_layer, end_layer)`, it takes the entry keyed
        `model.layers.<i>.self_attn.<selected_channel>_proj`, keeps its first
        `ds_heavy_channel_num` columns, and stores it as a contiguous CUDA tensor.

        Args:
            selected_channel: projection name prefix; `"." + selected_channel
                + "_proj"` is appended to each layer key.

        Raises:
            KeyError: if a layer's key is missing from the config file.
        """
        suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config. JSON text is UTF-8, so don't depend on the
        # platform's locale encoding.
        with open(
            self.server_args.ds_channel_config_path, "r", encoding="utf-8"
        ) as f:
            channel_config = json.load(f)

        heavy_channel_num = self.server_args.ds_heavy_channel_num
        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + suffix
            channels = torch.tensor(channel_config[key])[:, :heavy_channel_num]
            self.sorted_channels.append(channels.contiguous().cuda())

    def init_cuda_graphs(self):
        """Capture cuda graphs into `self.cuda_graph_runner` when applicable.

        The runner is left as `None` for non-generation models or when
        `server_args.disable_cuda_graph` is set. Otherwise a `CudaGraphRunner`
        is built, and the elapsed time and memory it consumed are logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        started_at = time.perf_counter()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        elapsed = time.perf_counter() - started_at
        logger.info(
            f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s. "
            f"mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )

    def apply_torch_tp(self):
        """Apply torch-native tensor parallelism to `self.model`.

        Builds a one-dimensional device mesh of `tp_size` entries on
        `self.device` and passes it, with the model, to
        `sglang.srt.model_parallel.tensor_parallel`.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        tp_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, tp_mesh)

    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model on a decode batch.

        The attention backend's per-batch metadata is initialized first.
        `pp_proxy_tensors` is forwarded to the model only when
        `self.support_pp` is set.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        extra = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        """Run the model on an extend batch.

        Args:
            forward_batch: the batch to run.
            skip_attn_backend_init: when true, the caller has already
                initialized the attention backend's metadata for this batch.
            pp_proxy_tensors: forwarded only when `self.support_pp` is set.

        Extra model kwargs: `input_embeds` (cast to bfloat16) when the batch
        carries embeddings, and `get_embedding=True` for non-generation models.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        extra = {}
        if self.support_pp:
            extra["pp_proxy_tensors"] = pp_proxy_tensors
        embeds = forward_batch.input_embeds
        if embeds is not None:
            # NOTE(review): always bfloat16, regardless of the model dtype.
            extra["input_embeds"] = embeds.bfloat16()
        if not self.is_generation:
            extra["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model on an idle batch.

        No attention-backend metadata is initialized here. `pp_proxy_tensors`
        is forwarded only when `self.support_pp` is set.
        """
        if self.support_pp:
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Run `forward_batch` through a CUDA-graph replay or the eager path.

        A captured graph is replayed when the forward mode is graph-capable
        and `cuda_graph_runner` exists and accepts the batch. Otherwise the
        batch goes to the decode, extend or idle handler.

        Returns:
            `(output, can_run_cuda_graph)`; the flag tells whether a graph was
            replayed.

        Raises:
            ValueError: if the forward mode is not decode, extend or idle.
        """
        mode = forward_batch.forward_mode
        can_run_cuda_graph = bool(
            mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            replayed = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return replayed, can_run_cuda_graph

        if mode.is_decode():
            output = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif mode.is_extend():
            output = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif mode.is_idle():
            output = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
        return output, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Make sure the grammar vocab mask is ready, then apply logit biases.

        If `sampling_info.sampling_info_done` is falsy, the mask is computed
        here. Otherwise it was produced elsewhere, and we only `wait()` on it
        when the batch has grammars.
        """
        done = sampling_info.sampling_info_done
        if not done:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        elif sampling_info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            done.wait()

        # Apply logit bias
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward. A tuple
                is treated as several output streams (duplex models): each is
                sampled separately and the results are stacked on the last axis.
            forward_batch: The forward batch that generates logits_output

        Returns:
            The next token ids, as returned by `self.sampler` (stacked along
            the last axis for tuple inputs).
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            per_stream = [self.sample(stream, forward_batch) for stream in logits_output]
            return torch.stack(per_stream, axis=-1)

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        return self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Whether the text config's `rope_scaling` has an "mrope_section" entry.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases. A missing or `None` `rope_scaling` counts as not mrope.
        """
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        return rope_scaling is not None and "mrope_section" in rope_scaling

    def save_remote_model(self, url: str):
        """Save the loaded model to `url` through `RemoteModelLoader.save_model`.

        The model path from `model_config` is passed along with the model.
        """
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        model_path = self.model_config.model_path
        RemoteModelLoader.save_model(self.model, model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the loaded model to `path` through `ShardedStateLoader.save_model`.

        `pattern` and `max_size` are forwarded unchanged.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        loader = ShardedStateLoader
        loader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Load each `(name, tensor)` pair into `model`'s parameter of that name.

    Every pair goes through `default_weight_loader`, bypassing the model's own
    `load_weights`. An unknown name raises `KeyError`.
    """
    parameters_by_name = dict(model.named_parameters())
    for param_name, loaded_tensor in named_tensors:
        default_weight_loader(parameters_by_name[param_name], loaded_tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return `tensor` moved to the current CUDA device.

    A `LocalSerializedTensor` is first deserialized into the entry for
    `tp_rank`, after `monkey_patch_torch_reductions()` has been applied.
    """
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor.to(torch.cuda.current_device())
    # NOTE(review): presumably needed so the deserializer can rebuild the
    # shared tensor; confirm against MultiprocessingSerializer.
    monkey_patch_torch_reductions()
    rank_tensor = tensor.get(tp_rank)
    return rank_tensor.to(torch.cuda.current_device())


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])