# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import collections
import datetime
import gc
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    enable_show_time_cost,
    get_available_gpu_memory,
    init_custom_process_group,
    is_cuda,
    is_hip,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

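    # A minimal construction sketch (illustrative only; in SGLang the scheduler / TP worker
    # normally creates this object, and `ModelConfig.from_server_args` below is a hypothetical
    # helper standing in for however the config is actually built):
    #
    #   server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
    #   model_config = ModelConfig.from_server_args(server_args)
    #   runner = ModelRunner(
    #       model_config=model_config,
    #       mem_fraction_static=server_args.mem_fraction_static,
    #       gpu_id=0,
    #       tp_rank=0,
    #       tp_size=1,
    #       nccl_port=29500,
    #       server_args=server_args,
    #   )
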
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # TODO: add MLA optimization on CPU
            if self.server_args.device != "cpu":
                if server_args.enable_flashinfer_mla:
                    logger.info(
                        "MLA optimization is turned on. Use flashinfer mla backend."
                    )
                    self.server_args.attention_backend = "flashinfer_mla"
                else:
                    logger.info("MLA optimization is turned on. Use triton backend.")
                    self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )

            if self.model_config.hf_config.architectures == [
                "MllamaForConditionalGeneration"
            ]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                server_args.chunked_prefill_size = -1

            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                )
                server_args.chunked_prefill_size = -1
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished loading the model, but some other ranks did not. "
                f"This is likely caused by an unexpected failure (e.g., OOM) or a slow node."
            ) from None

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            # TODO(liangan1): Just use gloo to bypass the initialization failure.
            # Need to use xccl for the xpu backend in the future.
            backend = "gloo"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove this monkey patch once the quantization code in linear.py no longer depends on vllm.
        monkey_patch_vllm_parallel_state()
        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            dist.barrier(group=self._model_update_group, device_ids=[rank])
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message
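
    # A rough sketch (illustrative only, not part of this module) of how an RLHF trainer
    # might drive these weight-update hooks. `engine`, `actor_model`, and `trainer_group`
    # are hypothetical names for the inference-engine handle, the training-side model, and
    # the trainer's handle to the custom process group:
    #
    #   engine.init_weights_update_group(
    #       master_address="127.0.0.1",
    #       master_port=29600,
    #       rank_offset=1,
    #       world_size=1 + tp_size,
    #       group_name="rlhf_weight_sync",
    #       backend="nccl",
    #   )
    #   for name, param in actor_model.named_parameters():
    #       # Rank 0 (the trainer) broadcasts each tensor over the custom group while the
    #       # inference ranks receive it inside update_weights_from_distributed().
    #       torch.distributed.broadcast(param.data, src=0, group=trainer_group)
    #       engine.update_weights_from_distributed(name, param.dtype, param.shape)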

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; the performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token
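
    # Worked example of the cell_size formula above (illustration only): with 8 KV heads per
    # TP rank, head_dim=128, 32 layers, and a 2-byte (fp16/bf16) KV cache dtype, the MHA branch
    # gives 8 * 128 * 32 * 2 (K and V) * 2 bytes = 131072 bytes per token, i.e. roughly 8192
    # cacheable tokens per free GiB of GPU memory.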

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
            else:
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    + max_num_reqs * self.server_args.speculative_num_steps
                    + 100
                )

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            enable_memory_saver=self.server_args.enable_memory_saver,
        )

        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

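    # Note: the channel config file read below is a JSON dict whose keys look like
    # "model.layers.<i>.self_attn.<type>_proj" (with <type> taken from --ds-heavy-channel-type)
    # and whose values are 2-D integer lists of channel indices; only the first
    # --ds-heavy-channel-num columns of each row are used. This describes the format as
    # consumed here; the file itself comes from an offline profiling step.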
    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.time()
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            if forward_batch.input_embeds is None:
                return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
                )
            else:
                return self.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    input_embeds=forward_batch.input_embeds.bfloat16(),
                )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward_idle(self, forward_batch: ForwardBatch):
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if (
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        ):
            return self.cuda_graph_runner.replay(forward_batch)

        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def update_output_logprobs(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_info: SamplingBatchInfo,
        top_logprobs_nums: List[int],
        token_ids_logprobs: List[int],
        next_token_ids: torch.Tensor,
        *,
        num_tokens_per_req: List[int],
    ):
        """Update the logits_output's output logprob based on next_token_ids

        Args:
            logits_output: The logits output from the model forward
            sampling_info: Sampling info for logprob calculation
            top_logprobs_nums: Number of top logprobs per request.
            token_ids_logprobs: The token ids whose logprobs are requested, per request.
            next_token_ids: Next token ids.
            num_tokens_per_req: The number of tokens per request.

        Returns:
            A list of next_token_ids
        """
        self._preprocess_logits(logits_output, sampling_info)
        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
        self.sampler(
            logits_output,
            sampling_info,
            True,
            top_logprobs_nums_repeat_interleaved,
            token_ids_logprobs_repeat_interleaved,
            batch_next_token_ids=next_token_ids,
        )

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids
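
    # Typical per-step flow (sketch): the caller builds a ForwardBatch, then runs
    #   logits_output = runner.forward(forward_batch)
    #   next_token_ids = runner.sample(logits_output, forward_batch)
    # and feeds next_token_ids back into the next batch.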

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"
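    # For example, multimodal models such as Qwen2-VL declare something like
    # `"rope_scaling": {"type": "mrope", "mrope_section": [...]}` in their HF config,
    # which makes this property return True; text-only models usually have no such entry.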


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    if isinstance(tensor, LocalSerializedTensor):
        return tensor.get(tp_rank)
    return tensor


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
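
# A rough usage sketch (illustration only) of the tensor-based weight-update path above:
# the sender serializes one tensor handle per TP rank (assuming MultiprocessingSerializer.serialize
# is the counterpart of the deserialize call used above), wraps them in LocalSerializedTensor,
# and each ModelRunner rank unwraps its own slice inside update_weights_from_tensor().
# `per_rank_tensors` is a hypothetical list with one tensor per TP rank:
#
#   serialized = LocalSerializedTensor(
#       values=[MultiprocessingSerializer.serialize(t) for t in per_rank_tensors]
#   )
#   model_runner.update_weights_from_tensor([("lm_head.weight", serialized)])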