# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.quantization.deep_gemm import (
    _ENABLE_JIT_DEEPGEMM,
    update_deep_gemm_config,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.eplb_manager import EPLBManager
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    init_custom_process_group,
    is_cuda,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)

_is_hip = is_hip()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Apply the rank zero filter to logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.forward_pass_id = 0

        # Model-specific adjustment
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "deepep_mode": server_args.deepep_mode,
                "device": server_args.device,
                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                "disable_radix_cache": server_args.disable_radix_cache,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_two_batch_overlap": server_args.enable_two_batch_overlap,
                "enable_dp_lm_head": server_args.enable_dp_lm_head,
                "enable_ep_moe": server_args.enable_ep_moe,
                "enable_deepep_moe": server_args.enable_deepep_moe,
                "deepep_config": server_args.deepep_config,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "torchao_config": server_args.torchao_config,
                "sampling_backend": server_args.sampling_backend,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "use_mla_backend": self.use_mla_backend,
                "mm_attention_backend": server_args.mm_attention_backend,
                "ep_num_redundant_experts": server_args.ep_num_redundant_experts,
            }
        )

        # CPU offload
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Update the DeepGEMM configuration
        if _ENABLE_JIT_DEEPGEMM:
            update_deep_gemm_config(gpu_id, server_args)

        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # Temporarily cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        if server_args.lora_paths is not None:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )

        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not cpu_has_amx_support()
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fall back to the torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.attention_backend is None:
            """
            Auto select the fastest attention backend.

            1. Models with MHA architecture (e.g., Llama, Qwen)
                1.1 We will turn on FA3 on Hopper unless the user uses speculative decoding
                    with topk > 1 or page_size > 1.
                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
            2. Models with MLA architecture
                2.1 We will use the FA3 backend on Hopper.
                2.2 Otherwise, we will use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
            )
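            # Note: this auto-selection only runs when no backend was specified; passing
            # `--attention-backend` (e.g., `--attention-backend flashinfer`) on the server
            # command line pins the backend and skips this branch.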
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
                    "cutlass_mla",
                ]:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization not supported on CPU except for intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant removes dependencies on vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return iter

        def model_load_weights(model, iter):
            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank
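        # Illustrative layout (values are hypothetical): with tp_size=8 and rank_offset=1,
        # the training engine's actor model holds rank 0 of the update group and this
        # engine's TP workers occupy ranks 1-8, so world_size would be 9.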

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
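        # Illustrative arithmetic (hypothetical MHA config, not tied to any real model):
        # 8 KV heads per TP rank * head_dim 128 * 32 layers * 2 (K and V) * 2 bytes/elem
        # gives cell_size = 131072 bytes per token; the usable KV budget computed below is
        # then roughly rest_memory (in GiB) * 2**30 / cell_size tokens.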
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )
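            # For example (purely illustrative): with max_total_num_tokens=262144 and
            # context_len=32768, 262144 / 32768 * 512 = 4096, so max_num_reqs is capped
            # at 4096; small KV budgets fall back to the lower bound of 2048.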

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
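                # Illustrative sizing (hypothetical EAGLE settings, not defaults): with
                # max_num_reqs=2048, speculative_num_steps=5, speculative_eagle_topk=4, and
                # speculative_num_draft_tokens=8, the headroom is
                # 2048 * 5 * 4 + 2048 * 8 + 100 = 40960 + 16384 + 100 = 57444 extra KV slots.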
                # The target worker and the draft worker share the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_two_batch_overlap:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    # TODO unify with 6338
    def _get_attention_backend(self):
        if self.server_args.attention_backend == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                return FlashInferAttnBackend(self)
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                return FlashInferMLAAttnBackend(self)
        elif self.server_args.attention_backend == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                return DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                return TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            return TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            return FlashMLABackend(self)
        elif self.server_args.attention_backend == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            return FlashAttentionBackend(self)
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            return CutlassMLABackend(self)
        elif self.server_args.attention_backend == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            logger.info("Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch, skip_attn_backend_init, pp_proxy_tensors
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            ret = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return ret, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled
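        # For reference, a Qwen2-VL-style HF config carries something like
        # rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}; the
        # "mrope_section" key is what this property keys off. (Values are illustrative.)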

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    if isinstance(tensor, LocalSerializedTensor):
        monkey_patch_torch_reductions()
        tensor = tensor.get(tp_rank)
    return tensor.to(torch.cuda.current_device())


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
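
# A minimal sketch of how update_weights_from_tensor is meant to be fed (assumptions:
# one serialized payload per TP rank and a matching MultiprocessingSerializer.serialize
# API; the parameter name below is illustrative, not part of this module):
#
#   serialized = [MultiprocessingSerializer.serialize(t) for t in per_rank_tensors]
#   runner.update_weights_from_tensor(
#       named_tensors=[
#           ("model.embed_tokens.weight", LocalSerializedTensor(values=serialized)),
#       ],
#   )
#
# Each rank then deserializes only its own slot via LocalSerializedTensor.get(tp_rank)
# inside _unwrap_tensor above.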