model_runner.py 53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import datetime
17
import gc
18
import inspect
Shuo Yang's avatar
Shuo Yang committed
19
import json
20
import logging
21
import os
22
import time
23
24
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
25
26

import torch
27
import torch.distributed as dist
28

29
from sglang.srt import debug_utils
30
31
32
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
33
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
34
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
35
    get_tp_group,
36
    get_world_group,
zhyncs's avatar
zhyncs committed
37
38
    init_distributed_environment,
    initialize_model_parallel,
39
    set_custom_all_reduce,
40
    set_mscclpp_all_reduce,
zhyncs's avatar
zhyncs committed
41
)
42
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
43
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
44
45
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
46
    get_attention_tp_size,
47
48
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
49
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
50
51
52
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
53
)
54
from sglang.srt.layers.sampler import Sampler
55
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
56
from sglang.srt.layers.utils import is_sm100_supported
57
from sglang.srt.lora.lora_manager import LoRAManager
58
from sglang.srt.managers.eplb_manager import EPLBManager
59
60
61
62
63
64
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
65
    ExpertLocationMetadata,
66
67
68
69
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
70
71
72
73
from sglang.srt.managers.schedule_batch import (
    GLOBAL_SERVER_ARGS_KEYS,
    global_server_args_dict,
)
74
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
75
    DoubleSparseTokenToKVPool,
76
77
78
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
79
    TokenToKVPoolAllocator,
80
)
Lianmin Zheng's avatar
Lianmin Zheng committed
81
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
82
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
83
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
84
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
85
from sglang.srt.model_loader import get_model
86
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
Lianmin Zheng's avatar
Lianmin Zheng committed
87
from sglang.srt.model_loader.utils import set_default_torch_dtype
88
from sglang.srt.model_loader.weight_utils import default_weight_loader
89
from sglang.srt.patch_torch import monkey_patch_torch_reductions
90
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
91
from sglang.srt.server_args import ServerArgs
92
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
93
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
94
from sglang.srt.utils import (
95
    MultiprocessingSerializer,
96
    cpu_has_amx_support,
97
    dynamic_import,
98
    enable_show_time_cost,
99
    get_available_gpu_memory,
100
    get_bool_env_var,
101
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
102
    is_cuda,
103
    is_fa3_default_architecture,
104
    is_flashinfer_available,
HAI's avatar
HAI committed
105
    is_hip,
106
    is_hopper_with_cuda_12_3,
107
    is_no_spec_infer_or_topk_one,
108
    monkey_patch_p2p_access_check,
109
    monkey_patch_vllm_gguf_config,
110
    set_cpu_offload_max_bytes,
111
    set_cuda_arch,
112
)
113

114
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading: number of seconds the TP group waits
# at the post-load barrier before reporting that some rank did not finish.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)

125

126
127
128
129
130
131
132
133
134
135
136
137
138
class RankZeroFilter(logging.Filter):
    """Logging filter that restricts INFO records to rank 0.

    Records at any level other than INFO pass through on every rank. INFO
    records pass only when the filter was created with ``is_rank_zero`` truthy
    (the stored value itself is returned for INFO records).
    """

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        is_info_record = record.levelno == logging.INFO
        return self.is_rank_zero if is_info_record else True


Lianmin Zheng's avatar
Lianmin Zheng committed
139
class ModelRunner:
140
141
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
142
143
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Initialize distributed state, load the model and build runtime state.

        Args:
            model_config: Model configuration; supplies the generation /
                multimodal / MLA flags and the attention chunk size.
            mem_fraction_static: Static memory fraction. It may be scaled down
                by ``model_specific_adjustment``.
            gpu_id: Local device index this runner binds to.
            tp_rank, tp_size: Tensor-parallel rank and size.
            pp_rank, pp_size: Pipeline-parallel rank and size.
            nccl_port: Port for the default ``tcp://127.0.0.1`` init method,
                used when ``server_args.dist_init_addr`` is unset.
            server_args: Server arguments. NOTE: mutated in place by
                ``model_specific_adjustment`` (attention backend, CUDA graph
                and chunked-prefill related flags).
            is_draft_worker: True for a speculative-decoding draft model; draft
                workers skip distributed-environment initialization and the
                global expert-location setup.
            req_to_token_pool, token_to_kv_pool_allocator: Optional pre-built
                pools, stored on the instance unchanged.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id

        # Apply the rank zero filter to the module logger, only once per
        # process so repeated construction does not stack duplicate filters.
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))

        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size

        self.forward_pass_id = 0

        # Model-specific adjustment. Runs before the global_server_args_dict
        # snapshot below, so the snapshot sees the adjusted server_args.
        self.model_specific_adjustment()

        if server_args.show_time_cost:
            enable_show_time_cost()

        # Global vars
        global_server_args_dict.update(
            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
            | {
                # TODO it is indeed not a "server args"
                "use_mla_backend": self.use_mla_backend,
                "speculative_algorithm": self.spec_algorithm,
            }
        )

        # CPU offload
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Initialize torch.distributed; returns the minimum available device
        # memory across ranks, measured before the model is loaded.
        min_per_gpu_memory = self.init_torch_distributed()

        # Update deep gemm configure
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Load the model and build memory pools / attention backends.
        self.initialize(min_per_gpu_memory)

        # Cached: whether the loaded model's forward() accepts pipeline-parallel
        # proxy tensors.
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

224
225
    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build everything needed to run forward passes.

        In order, this sets up:

        - the memory saver adapter;
        - expert-location metadata and the expert distribution recorder
          (non-draft workers only);
        - the optional EPLB manager and the expert location updater;
        - the sampler and the model;
        - torchao quantization and torch TP;
        - LoRA;
        - the memory pool;
        - attention backends, plus cuBLAS and CUDA graphs on CUDA.

        Args:
            min_per_gpu_memory: Minimum available device memory across ranks,
                as returned by ``init_torch_distributed``; forwarded to
                ``init_memory_pool``.
        """
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        # The expert-location metadata and expert distribution recorder are
        # installed through the module-level set_global_* setters; only the
        # target (non-draft) worker does this.
        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Layer range [start_layer, end_layer) implemented by this model
        # instance; defaults to all hidden layers when the model does not
        # expose these attributes.
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization.
        # In layered loading, torchao may already have been applied.
        torchao_applied = getattr(self.model, "torchao_applied", False)
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init lora
        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
        # load LoRA adapters dynamically later.
        if server_args.lora_paths is not None:
            self.init_lora_manager()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

306
307
308
    def model_specific_adjustment(self):
        """Adjust server_args and runner settings for this model and platform.

        Mutates ``self.server_args`` in place: the attention backend, the CUDA
        graph flag, chunked prefill and the chunked prefix cache. May also
        scale down ``self.mem_fraction_static``.

        Raises:
            ValueError: if an explicitly chosen attention backend is not valid
                for an MLA model, or double sparsity is enabled without
                ``ds_heavy_channel_type``.
        """
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fallback to torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.attention_backend is None:
            # Auto select the fastest attention backend.
            #
            # 1. Models with MHA architecture (e.g. Llama, Qwen)
            #   1.1 Use FA3 on Hopper with CUDA >= 12.3 when
            #       is_no_spec_infer_or_topk_one(server_args) holds (no
            #       speculative decoding, or topk == 1) and the architecture is
            #       an FA3-default one.
            #   1.2 Otherwise use aiter on HIP.
            #   1.3 Otherwise use flashinfer if available, else triton.
            # 2. Models with MLA architecture
            #   2.1 Use FA3 on Hopper with CUDA >= 12.3.
            #   2.2 Use flashinfer on Blackwell (SM100).
            #   2.3 On HIP, use aiter when the KV head count is 16 or 128 and
            #       speculative decoding is off; otherwise triton.
            #   2.4 Otherwise use triton.
            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO current aiter only support head number 16 or 128 head number
                    if head_num in (16, 128) and self.spec_algorithm.is_none():
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            # An explicitly chosen backend must be one that supports MLA.
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "aiter",
                    "flashinfer",
                    "fa3",
                    "triton",
                    "flashmla",
                    "cutlass_mla",
                ]:
                    logger.info(
                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization not supported on CPU except for intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
                "Setting attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.90
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                "because this is a multimodal model."
            )
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    "Automatically turn off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        # The chunked prefix cache is only used with MLA and page size 1.
        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
            logger.info("Disable chunked prefix cache when page size > 1.")
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        # aiter with long contexts: reserve a smaller static memory fraction.
        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

437
    def init_torch_distributed(self):
        """Bind the device and initialize torch.distributed and parallel groups.

        Only the target (non-draft) worker initializes the distributed
        environment, the model-parallel groups and DP attention. Draft workers
        skip that step and only fetch the TP / attention-TP groups.

        Returns:
            The minimum available device memory across the world group, as
            reported by ``get_available_gpu_memory`` (logged as GB).

        Raises:
            ValueError: if ``self.device`` has no known collective backend, or
                if free memory is unbalanced across TP ranks (downgraded to a
                warning when SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK is set).
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Collective-communication backend for each supported device type.
        backend_by_device = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
            "npu": "hccl",
        }
        backend = backend_by_device.get(self.device)
        if backend is None:
            # Fail clearly here instead of with an UnboundLocalError later.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device}"
            )

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                pp_size=self.server_args.pp_size,
            )

        world_group = get_world_group()
        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=world_group.world_size > 1,
            cpu_group=world_group.cpu_group,
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism: this rank having notably more
        # free memory than the poorest rank suggests GPUs shared with other
        # processes.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            message = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(message)
            else:
                raise ValueError(message)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
521

Lianmin Zheng's avatar
Lianmin Zheng committed
522
    def load_model(self):
        """Load the model weights onto the device.

        Side effects:

        - on CUDA devices below SM80, downgrades ``server_args.dtype`` and
          ``model_config.dtype`` to float16;
        - sets ``self.load_config``, ``self.model``,
          ``self.sliding_window_size`` and ``self.dtype``;
        - loads FP8 KV-cache scales when configured;
        - waits at a TP-group barrier until every rank has finished loading.

        Raises:
            RuntimeError: on CUDA devices older than SM75, or when FP8 KV-cache
                scaling factors are provided but the model cannot load them.
            ValueError: if other TP ranks do not reach the post-load barrier
                within ``UNBALANCED_MODEL_LOADING_TIMEOUT_S`` seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            capability = tuple(torch.cuda.get_device_capability())
            # Compare the full (major, minor) pair. Checking only the minor
            # version let devices such as sm37 pass this check.
            if capability < (7, 5):
                raise RuntimeError("SGLang only supports sm75 and above.")
            if capability < (8, 0):
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Revert the patches even if loading fails.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            param_path = self.server_args.quantization_param_path
            if param_path is None:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )
            elif callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(param_path)
                logger.info("Loaded KV cache scaling factors from %s", param_path)
            else:
                # Exceptions do not %-format extra args, so build the message here.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    f"model {type(self.model).__name__} does not support "
                    "loading scaling factors."
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

617
    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        """Hand the routed-expert weights and new expert layout to the updater.

        Calls ``self.expert_location_updater.update`` with the model's
        ``routed_experts_weights_of_layer`` and the new metadata. It also
        passes the selected layer ids, the server's node count and this
        runner's TP rank.
        """
        updater = self.expert_location_updater
        routed_weights = self.model.routed_experts_weights_of_layer
        updater.update(
            routed_weights,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

630
631
632
633
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Reloads weights into the existing model from ``model_path`` using
        ``load_format``. Only loaders that are ``DefaultModelLoader``
        instances are supported.

        On every failure path ``model_config.model_path`` is restored to the
        original checkpoint. If the failure happened while loading weights,
        the original weights are reloaded (rolled back) before returning.

        Returns:
            ``(success, message)``.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # The weight iterator is built from model_config, so the new path is
        # written there; remember the old one so failures can restore it.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            return loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Restore the original path BEFORE rebuilding the iterator, so
                # the rollback re-reads the original checkpoint rather than the
                # one that just failed. NOTE: the rollback still uses the
                # loader built for the new load_format.
                self.model_config.model_path = original_model_path
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
684

685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Create the process group used to push fresh weights into this engine.

        In the RLHF workflow, rank 0 of `_model_update_group` is the actor model
        of the training engine and the remaining ranks are the inference
        (rollout) engine. The trainer updates parameters online and broadcasts
        them to the inference side through this group.

        Args:
            master_address: Host used for the TCP rendezvous.
            master_port: Port used for the TCP rendezvous.
            rank_offset: Added to this worker's `tp_rank` to get its rank in the
                update group.
            world_size: Total number of ranks in the update group.
            group_name: Name of the group; must not be the empty string.
            backend: torch.distributed backend, "nccl" by default.

        Returns:
            A `(success, message)` tuple. Errors raised while creating the group
            are logged and reported through this tuple instead of propagating.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        group_rank = rank_offset + self.tp_rank

        logger.info(
            "init custom process group: "
            f"master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={group_rank}, world_size={world_size}, "
            f"group_name={group_name}, backend={backend}"
        )

        rendezvous = f"tcp://{master_address}:{master_port}"
        try:
            group = init_custom_process_group(
                backend=backend,
                init_method=rendezvous,
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
            )
        except Exception as exc:
            failure = f"Failed to initialize custom process group: {exc}."
            logger.error(failure)
            return False, failure

        self._model_update_group = group
        return True, "Succeeded to initialize custom process group."

    def update_weights_from_distributed(self, name, dtype, shape):
        """Receive one parameter through `_model_update_group` and load it.

        Args:
            name: Name of the parameter to overwrite.
            dtype: A `torch.dtype`, or the name of a `torch` attribute
                (for example "bfloat16").
            shape: Shape of the tensor broadcast by rank 0.

        Returns:
            A `(success, message)` tuple. On failure the error is logged; the
            message warns that the weights may be partially updated and should
            be discarded as a whole.
        """
        if isinstance(dtype, torch.dtype):
            recv_dtype = dtype
        else:
            recv_dtype = getattr(torch, dtype)

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            buffer = torch.empty(shape, dtype=recv_dtype, device=self.device)
            # Rank 0 of the update group is the sender (the training engine).
            torch.distributed.broadcast(buffer, src=0, group=self._model_update_group)
            self.model.load_weights([(name, buffer)])
        except Exception as exc:
            error_msg = (
                f"Failed to update parameter online: {exc}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

        return True, f"Succeeded to update parameter {name} online."

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load in-memory tensors into the model.

        Args:
            named_tensors: `(name, tensor)` pairs. Each tensor may be a plain
                `torch.Tensor` or a `LocalSerializedTensor`; it is unwrapped for
                this worker's `tp_rank` and moved to the current CUDA device
                before loading.
            load_format: Selects how the tensors are applied:
                - `None`: the model's own `load_weights`.
                - `"direct"`: each tensor is copied into the parameter of the
                  same name via `default_weight_loader`.
                - any value listed in `server_args.custom_weight_loader`: that
                  dotted path is imported and called as
                  `loader(model, named_tensors)`.

        Returns:
            `(True, "Success")`.

        Raises:
            NotImplementedError: if `load_format` matches none of the above.
        """
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        # Check `None` before the allow-list membership test, which would
        # otherwise fail when `custom_weight_loader` is unset (None).
        if load_format is None:
            self.model.load_weights(named_tensors)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in (self.server_args.custom_weight_loader or ()):
            # Only loaders explicitly allow-listed in server_args are imported.
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Fetch a parameter by name, similar to Hugging Face's `get_parameter`.

        Meant for unit tests only and not optimized; use `torch.save` /
        `torch.load` when performance matters.

        Args:
            name: Parameter name.
            truncate_size: Forwarded to the model's `get_weights_by_name`.

        Returns:
            Whatever the model returns, or `None` if the lookup raised (the
            error is logged).
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as err:
            logger.error(f"Error when getting parameter {name}: {err}")
            weights = None
        return weights

    def init_lora_manager(self):
        """Create `self.lora_manager` and load the adapters in `server_args.lora_paths`."""
        args = self.server_args
        manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
        )
        self.lora_manager = manager
        manager.load_lora_adapters(args.lora_paths)
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many KV-cache tokens fit in the memory left for the pool.

        Args:
            total_gpu_memory: Total device memory, in the same unit (GB) as
                `get_available_gpu_memory`, which it is combined with below.

        Returns:
            The number of tokens, as an int. It can be zero or negative when
            too little memory is available.
        """
        world = get_world_group()
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=world.world_size > 1,
            cpu_group=world.cpu_group,
        )

        if not self.is_draft_worker:
            num_layers = self.num_effective_layers
        else:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )

        if self.use_mla_backend:
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            # Per token and layer: kv_lora_rank + qk_rope_head_dim elements.
            elems_per_layer = (
                self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim
            )
            cell_size = (
                elems_per_layer
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            kv_heads = self.model_config.get_num_kv_heads(get_attention_tp_size())
            # The factor 2 presumably covers separate K and V entries.
            cell_size = (
                kv_heads
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )

        # Keep (1 - mem_fraction_static) of the total memory out of the pool.
        reserved = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved
        return int(rest_memory * (1 << 30) // cell_size)

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and create the request-to-token pool, the KV cache pool and its allocator.

        Args:
            total_gpu_memory: Forwarded to `profile_max_num_token` to compute
                the token budget.
            max_num_reqs: Number of request slots. When `None` it is derived
                from the profiled token budget and clamped to [2048, 4096]. A
                speculative draft worker uses `server_args.max_num_reqs` instead.
            max_total_tokens: Optional cap on the KV cache size in tokens. It
                can only lower the profiled value.

        Raises:
            ValueError: if `server_args.kv_cache_dtype` is not supported.
            RuntimeError: if less than one page of tokens fits in memory.
        """
        kv_cache_dtype = self.server_args.kv_cache_dtype
        if kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif kv_cache_dtype == "fp8_e5m2":
            # ROCm uses its natively supported fnuz variant.
            if _is_hip:
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif kv_cache_dtype == "fp8_e4m3":
            # Same platform split as fp8_e5m2, so every platform gets a dtype
            # (only CUDA used to set one here).
            if _is_hip:
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            # 512 request slots per full context length of tokens, clamped to [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger rather than the root logger.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # subscribe memory for pre-allocated requests
                # if max_num_reqs <= 32, we pre-allocate 2x requests
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len + 4,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            # Only the draft worker receives an allocator shared from the target worker.
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """Warm up cuBLAS with a tiny fp16 matmul on CUDA.

        Without this warm-up, cuBLAS-related errors are raised later on.
        Returns the (16, 16) product tensor.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones_like(lhs)
        return lhs @ rhs

    def init_attention_backend(self):
        """Create `self.attn_backend`.

        With two-batch overlap enabled, the backend is a `TboAttnBackend` built
        from the `_get_attention_backend` factory. Otherwise it is the single
        backend returned by that factory.
        """
        factory = self._get_attention_backend
        self.attn_backend = (
            TboAttnBackend.init_new(factory)
            if self.server_args.enable_two_batch_overlap
            else factory()
        )

    # TODO unify with 6338
    def _get_attention_backend(self):
1043
        if self.server_args.attention_backend == "flashinfer":
1044
1045
1046
1047
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )
1048

1049
1050
1051
                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
1052
                return FlashInferAttnBackend(self)
1053
1054
1055
1056
1057
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

1058
                return FlashInferMLAAttnBackend(self)
1059
1060
1061
        elif self.server_args.attention_backend == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

1062
            return AiterAttnBackend(self)
1063
1064
1065
1066
1067
1068
        elif self.server_args.attention_backend == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
1069
1070
1071
1072
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

1073
                return DoubleSparseAttnBackend(self)
1074
            else:
1075
1076
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

1077
                return TritonAttnBackend(self)
1078
        elif self.server_args.attention_backend == "torch_native":
1079
1080
1081
1082
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

1083
            return TorchNativeAttnBackend(self)
lukec's avatar
lukec committed
1084
1085
1086
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

1087
            return FlashMLABackend(self)
1088
        elif self.server_args.attention_backend == "fa3":
1089
1090
1091
1092
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
1093
1094
1095
1096
1097
1098
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

1099
            return FlashAttentionBackend(self)
1100
1101
1102
1103
1104
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

1105
            return CutlassMLABackend(self)
1106
1107
1108
1109
1110
1111
1112
        elif self.server_args.attention_backend == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )

            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
1113
1114
1115
1116
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
1117

Shuo Yang's avatar
Shuo Yang committed
1118
1119
1120
1121
1122
1123
1124
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer heavy channels used by double sparsity.

        Reads the JSON file at `server_args.ds_channel_config_path`. For every
        layer `i` in `[start_layer, end_layer)`, it takes the entry
        `model.layers.{i}.self_attn.{selected_channel}_proj`, keeps its first
        `ds_heavy_channel_num` columns, moves it to CUDA, and appends it to
        `self.sorted_channels` in layer order.
        """
        suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for layer_id in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(layer_id) + ".self_attn" + suffix
            heavy = torch.tensor(channel_config[key])[
                :, : self.server_args.ds_heavy_channel_num
            ]
            self.sorted_channels.append(heavy.contiguous().cuda())

    def init_cuda_graphs(self):
        """Capture CUDA graphs with `CudaGraphRunner`.

        `self.cuda_graph_runner` stays `None` for non-generation models, or
        when `server_args.disable_cuda_graph` is set. Capture time and memory
        usage are logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        start = time.perf_counter()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        elapsed = time.perf_counter() - start
        logger.info(
            f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s. "
            f"mem usage={(mem_before - mem_after):.2f} GB. avail mem={mem_after:.2f} GB."
        )

    def apply_torch_tp(self):
        """Shard `self.model` with torch tensor parallelism over a 1-D mesh of `tp_size` devices."""
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, mesh)

    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run a decode step: initialize attention metadata, then call the model.

        `pp_proxy_tensors` is forwarded only when the model supports pipeline
        parallelism (`self.support_pp`).
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        extra = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        """Run an extend step through the model.

        Attention metadata is initialized unless `skip_attn_backend_init` is
        set. Optional keyword arguments are passed to the model only when they
        apply:
        - `pp_proxy_tensors`, if the model supports pipeline parallelism;
        - `input_embeds` (cast to bfloat16), if the batch carries them;
        - `get_embedding=True`, for non-generation models.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        extra = {}
        if self.support_pp:
            extra["pp_proxy_tensors"] = pp_proxy_tensors
        embeds = forward_batch.input_embeds
        if embeds is not None:
            extra["input_embeds"] = embeds.bfloat16()
        if not self.is_generation:
            extra["get_embedding"] = True

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        """Run the model on an idle batch without initializing attention metadata.

        `pp_proxy_tensors` is forwarded only when `self.support_pp` is set.
        """
        extra = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **extra
        )

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Run one forward pass inside the expert-distribution recorder.

        Increments `forward_pass_id`, delegates to `_forward_raw`, then notifies
        the EPLB manager (if one is configured) that the pass finished.

        Returns:
            `(output, can_run_cuda_graph)` as produced by `_forward_raw`.
        """
        self.forward_pass_id += 1
        recorder = get_global_expert_distribution_recorder()

        with recorder.with_forward_pass(self.forward_pass_id, forward_batch):
            result = self._forward_raw(
                forward_batch, skip_attn_backend_init, pp_proxy_tensors
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return result

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        """Replay a CUDA graph when possible; otherwise dispatch on the forward mode.

        Returns:
            `(output, can_run_cuda_graph)`.

        Raises:
            ValueError: if no CUDA graph is replayed and the mode is not decode,
                extend or idle.
        """
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )

        if can_run_cuda_graph:
            output = self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_decode():
            output = self.forward_decode(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
        elif forward_batch.forward_mode.is_extend():
            output = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            output = self.forward_idle(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        return output, can_run_cuda_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Make sure the grammar vocab mask is ready, then apply the logit bias.

        The bias is applied to `logits_output.next_token_logits`.
        """
        done_event = sampling_info.sampling_info_done
        if not done_event:
            # Normal mode: put CPU-heavy tasks here so they overlap with the forward pass.
            sampling_info.update_regex_vocab_mask()
        elif sampling_info.grammars:
            # Overlap mode: update_regex_vocab_mask already ran in
            # process_batch_result of the previous batch; wait on it.
            done_event.wait()

        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample next tokens, compute logprobs, and update `logits_output`.

        Args:
            logits_output: The logits produced by the model forward.
            forward_batch: The batch that produced `logits_output`.

        Returns:
            A tensor of next-token ids. When `logits_output` is a tuple (duplex
            models with several output streams), each stream is sampled
            separately and the results are stacked along the last dimension.
        """
        if isinstance(logits_output, tuple):
            per_stream = [
                self.sample(stream, forward_batch) for stream in logits_output
            ]
            return torch.stack(per_stream, axis=-1)

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        return self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Whether the text config's `rope_scaling` contains an `mrope_section` key.

        mrope requires keeping `rope_deltas` between the prompt and decoding
        phases. A missing or `None` `rope_scaling` means not mrope.
        """
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        return rope_scaling is not None and "mrope_section" in rope_scaling

    def save_remote_model(self, url: str):
        """Save the loaded model to `url` with `RemoteModelLoader.save_model`."""
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        source_path = self.model_config.model_path
        RemoteModelLoader.save_model(self.model, source_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the model to `path` in shards with `ShardedStateLoader.save_model`.

        `pattern` and `max_size` are passed through unchanged.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return `tensor` moved to the current CUDA device.

    A `LocalSerializedTensor` is first deserialized to its entry for `tp_rank`,
    after `monkey_patch_torch_reductions()` has been applied.
    """
    if isinstance(tensor, LocalSerializedTensor):
        monkey_patch_torch_reductions()
        unwrapped = tensor.get(tp_rank)
    else:
        unwrapped = tensor
    return unwrapped.to(torch.cuda.current_device())


@dataclass
class LocalSerializedTensor:
    """A torch.Tensor serialized by MultiprocessingSerializer.

    The serializer stores only a pointer, not the tensor data. `values[i]` is
    the serialized tensor for the GPU of rank `i`.
    """

    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the tensor belonging to `rank`."""
        serialized = self.values[rank]
        return MultiprocessingSerializer.deserialize(serialized)