# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    enable_show_time_cost,
    get_available_gpu_memory,
    init_custom_process_group,
    is_cuda,
    is_hip,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300


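# Illustrative usage sketch (not from the original file; the caller and the argument
# values are assumptions): a ModelRunner is typically constructed by a tensor-parallel
# model worker and then driven batch by batch, roughly as
#
#     runner = ModelRunner(
#         model_config=model_config,
#         mem_fraction_static=server_args.mem_fraction_static,
#         gpu_id=gpu_id,
#         tp_rank=tp_rank,
#         tp_size=server_args.tp_size,
#         nccl_port=nccl_port,
#         server_args=server_args,
#     )
#     logits_output = runner.forward(forward_batch)
#     next_token_ids = runner.sample(logits_output, forward_batch)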
class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # TODO: add MLA optimization on CPU
            if self.server_args.device != "cpu":
                if server_args.enable_flashinfer_mla:
                    logger.info(
                        "MLA optimization is turned on. Use flashinfer mla backend."
                    )
                    self.server_args.attention_backend = "flashinfer_mla"
                else:
                    logger.info("MLA optimization is turned on. Use triton backend.")
                    self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )

            if self.model_config.hf_config.architectures == [
                "MllamaForConditionalGeneration"
            ]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                server_args.chunked_prefill_size = -1

            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                )
                server_args.chunked_prefill_size = -1
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished loading the model, but other ranks have not. "
                "This is likely caused by unexpected failures (e.g., OOM) or a slow node."
            ) from None

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        torch.get_device_module(self.device).set_device(self.gpu_id)

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove this monkey patch when the quantization in linear.py no longer depends on vLLM.
        monkey_patch_vllm_parallel_state()
        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk."""
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            dist.barrier(group=self._model_update_group, device_ids=[rank])
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message
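
    # Illustrative sketch of the training-engine side of the RLHF weight sync described
    # above (not part of this file; the actor model and variable names are assumptions).
    # Rank 0 of the training engine joins the same custom group and broadcasts each
    # updated parameter, which `update_weights_from_distributed` below receives:
    #
    #     group = init_custom_process_group(
    #         backend="nccl",
    #         init_method=f"tcp://{master_address}:{master_port}",
    #         world_size=world_size,
    #         rank=0,
    #         group_name=group_name,
    #     )
    #     for name, param in actor_model.named_parameters():
    #         torch.distributed.broadcast(param.data, src=0, group=group)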

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
        )
        logger.info("LoRA manager ready.")

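    # Worked example of the KV-cache sizing done in `profile_max_num_token` below, with
    # assumed values: 32 KV heads on this TP rank, head_dim=128, 32 layers, and an
    # fp16 KV cache (2 bytes) give a per-token cell size of
    # 32 * 128 * 32 * 2 (K and V) * 2 bytes = 512 KiB, i.e. about 2048 tokens per GiB
    # of the memory budget left after weights and the static fraction.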
    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # The `token_to_kv_pool` is shared, and verify and draft tokens can be
                # allocated concurrently, so leave some headroom for them.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and draft worker share the same indices for the
                # token_to_kv_pool, so max_total_num_tokens must match.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )

        if self.token_to_kv_pool_allocator is None:
            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                device=self.device,
                kvcache=self.token_to_kv_pool,
            )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

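    # The channel config loaded below is a JSON file keyed by layer and projection,
    # with per-head channel indices as values; a hypothetical entry looks like:
    #
    #     {"model.layers.0.self_attn.q_proj": [[12, 7, 93, ...], ...], ...}
    #
    # Only the first `ds_heavy_channel_num` channels of each row are kept.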
    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        tic = time.time()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ):
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        if self.is_generation:
            if forward_batch.input_embeds is None:
                return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
                )
            else:
                return self.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    input_embeds=forward_batch.input_embeds.bfloat16(),
                )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward_idle(self, forward_batch: ForwardBatch):
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if (
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        ):
            return self.cuda_graph_runner.replay(forward_batch)

        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def update_output_logprobs(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_info: SamplingBatchInfo,
        top_logprobs_nums: List[int],
        token_ids_logprobs: List[int],
        next_token_ids: torch.Tensor,
        *,
        num_tokens_per_req: List[int],
    ):
        """Update the logits_output's output logprob based on next_token_ids

        Args:
            logits_output: The logits output from the model forward
            sampling_info: Sampling info for logprob calculation
            top_logprobs_nums: Number of logprobs per request.
            next_token_ids: Next token ids.
            num_tokens_per_req: The number of tokens per request.

        Returns:
            A list of next_token_ids
        """
        self._preprocess_logits(logits_output, sampling_info)
        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
        self.sampler(
            logits_output,
            sampling_info,
            True,
            top_logprobs_nums_repeat_interleaved,
            token_ids_logprobs_repeat_interleaved,
            batch_next_token_ids=next_token_ids,
        )

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A list of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"
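
    # For reference, an mrope-style entry in the HF config that this property matches
    # looks roughly like the following (values are illustrative, Qwen2-VL-style):
    #
    #     "rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}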


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    if isinstance(tensor, LocalSerializedTensor):
        return tensor.get(tp_rank)
    return tensor


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
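

# Illustrative sketch of how LocalSerializedTensor pairs with
# ModelRunner.update_weights_from_tensor (not part of the original file; the sender-side
# use of MultiprocessingSerializer.serialize is an assumption):
#
#     payload = LocalSerializedTensor(
#         values=[MultiprocessingSerializer.serialize(t) for t in per_rank_tensors]
#     )
#     model_runner.update_weights_from_tensor([("lm_head.weight", payload)])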