# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    init_custom_process_group,
    is_cuda,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
_is_hip = is_hip()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Apply the rank zero filter to logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.forward_pass_id = 0

        # Model-specific adjustment
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "deepep_mode": server_args.deepep_mode,
                "device": server_args.device,
                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                "disable_radix_cache": server_args.disable_radix_cache,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "enable_deepep_moe": server_args.enable_deepep_moe,
                "deepep_config": server_args.deepep_config,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
                "n_share_experts_fusion": server_args.n_share_experts_fusion,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "torchao_config": server_args.torchao_config,
                "sampling_backend": server_args.sampling_backend,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "use_mla_backend": self.use_mla_backend,
                "mm_attention_backend": server_args.mm_attention_backend,
            }
        )

        # CPU offload
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.lora_paths is not None:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

    def model_specific_adjustment(self):
        server_args = self.server_args

        if server_args.attention_backend is None:
            """
            Auto select the fastest attention backend.

            1. Models with an MHA architecture (e.g., Llama, Qwen)
                1.1 We turn on FA3 on Hopper unless the user uses speculative decoding with topk > 1 or page_size > 1.
                1.2 In other cases, we use flashinfer if available, otherwise triton.
            2. Models with an MLA architecture
                2.1 We use the FA3 backend on Hopper.
                2.2 Otherwise, we use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
                    "cutlass_mla",
                ]:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                raise ValueError("MLA optimization not supported on CPU.")

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
            )
            server_args.chunked_prefill_size = -1
            logger.info(
                "Automatically turn off --chunked-prefill-size for multimodal model."
            )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
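            # Rank layout note: world_size = tp_size * pp_size and global ranks are
            # assigned PP-major, i.e. rank = tp_size * pp_rank + tp_rank (for example,
            # tp_size=4, pp_rank=1, tp_rank=2 gives global rank 6).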
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
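                # min_per_gpu_memory is the minimum free memory across all ranks; if it
                # falls more than ~10% below this rank's free memory, another process is
                # probably occupying one of the GPUs (see the messages below).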
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished model loading, but some other ranks did not. "
                "This is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

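    # Note: incoming tensors may be plain torch.Tensors or LocalSerializedTensor
    # handles (one serialized per-rank pointer, see the dataclass at the end of this
    # file); _unwrap_tensor() resolves them to this TP rank's tensor before loading.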
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.use_mla_backend:
            num_layers = (
                self.model_config.num_hidden_layers
                if not self.is_draft_worker
                else self.model_config.hf_config.num_nextn_predict_layers
            )
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.num_effective_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
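        # cell_size is the per-token KV-cache footprint in bytes: MLA keeps one
        # compressed vector of width (kv_lora_rank + qk_rope_head_dim) per layer,
        # while MHA keeps K and V (the factor 2) for every local KV head. The token
        # budget below is then roughly
        #   (available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)) / cell_size.
        # Illustrative example only: ~60 GB of KV headroom at ~70 KB per token
        # allows roughly 9e5 tokens.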
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
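            # An fp8 KV cache stores one byte per element (half of fp16/bf16);
            # ROCm uses the *fnuz fp8 variants, hence the branch below.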
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
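            # Default request cap: scale the number of full-context requests that fit
            # in the token budget by 512 (assuming typical requests are much shorter
            # than context_len), then clamp to [2048, 4096].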
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and the draft worker share the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

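        # Round the token budget down to a whole number of pages so the paged
        # allocator never over-commits. Illustrative example only: 1,000,003 tokens
        # with page_size=16 becomes 1,000,000 tokens (62,500 pages).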
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
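            # page_size == 1 uses the plain per-token allocator; larger page sizes go
            # through PagedTokenToKVPoolAllocator, which hands out whole pages.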
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
        elif self.server_args.attention_backend == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            self.attn_backend = AiterAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        elif self.server_args.attention_backend == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

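    # Forward entry points: forward_decode runs one decode step, forward_extend runs
    # prefill/extend, and forward_idle covers batches with no tokens to compute
    # (e.g., an idle rank under data-parallel attention). forward() dispatches between
    # these and CUDA-graph replay in _forward_raw().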
    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            return self._forward_raw(
                forward_batch, skip_attn_backend_init, pp_proxy_tensors
            )

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            ret = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return ret, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    if isinstance(tensor, LocalSerializedTensor):
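        # A LocalSerializedTensor carries one MultiprocessingSerializer payload per TP
        # rank (a handle, not the data); pick this rank's entry and deserialize it.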
        monkey_patch_torch_reductions()
        tensor = tensor.get(tp_rank)
    return tensor.to(torch.cuda.current_device())


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])