model_runner.py 41.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import datetime
17
import gc
Shuo Yang's avatar
Shuo Yang committed
18
import json
19
import logging
20
import os
21
import time
22
23
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
24
25

import torch
26
import torch.distributed as dist
27
28
29
30
31

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
32
33
34
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
35
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
36
)
37
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
38
39
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
40
    get_attention_tp_size,
41
42
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
43
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
45
from sglang.srt.layers.sampler import Sampler
46
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
47
from sglang.srt.lora.lora_manager import LoRAManager
48
from sglang.srt.managers.schedule_batch import global_server_args_dict
49
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
50
    DoubleSparseTokenToKVPool,
51
52
53
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
54
    TokenToKVPoolAllocator,
55
)
Lianmin Zheng's avatar
Lianmin Zheng committed
56
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
57
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
58
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
59
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
60
61
62
63
64
65
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
66
from sglang.srt.model_loader.weight_utils import default_weight_loader
67
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
68
from sglang.srt.server_args import ServerArgs
69
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
70
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
71
from sglang.srt.utils import (
72
    MultiprocessingSerializer,
73
    enable_show_time_cost,
74
    get_available_gpu_memory,
75
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
76
    is_cuda,
HAI's avatar
HAI committed
77
    is_hip,
78
    monkey_patch_p2p_access_check,
79
    monkey_patch_vllm_gguf_config,
80
    set_cpu_offload_max_bytes,
81
    set_cuda_arch,
82
)
83

Ying Sheng's avatar
Ying Sheng committed
84
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
85

86
87
88
89
# Optional override (read from the environment, presumably for CI runs) of the
# KV-cache capacity: when set, init_memory_pool replaces the profiled
# max_total_num_tokens with int(SGLANG_CI_SMALL_KV_SIZE).
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Seconds load_model waits at a monitored barrier for every TP rank to finish
# loading weights before reporting that some rank failed or is too slow.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300


Lianmin Zheng's avatar
Lianmin Zheng committed
90
class ModelRunner:
91
92
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
93
94
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Record the runner configuration and bring the runner up.

        Steps, in order:

        1. copy the constructor arguments (plus a few values derived from
           ``model_config`` / ``server_args``) onto ``self``;
        2. let :meth:`model_specific_adjustment` patch ``server_args`` and
           ``mem_fraction_static`` for the concrete model;
        3. apply process-wide settings: time-cost logging, the outlines disk
           cache, ``global_server_args_dict`` and the CPU-offload budget;
        4. initialize torch distributed and pass the memory figure it returns
           to :meth:`initialize`.
        """
        args = server_args

        # -- 1. Arguments and derived attributes.
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = args.device
        self.gpu_id = gpu_id
        self.tp_rank, self.tp_size = tp_rank, tp_size
        self.dist_port = nccl_port
        self.server_args = args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            args.speculative_algorithm
        )
        self.page_size = args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # -- 2. Model-specific adjustment (may rewrite ``args`` in place).
        self.model_specific_adjustment()

        # -- 3. Process-wide settings.
        if args.show_time_cost:
            enable_show_time_cost()
        if args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        # Publish selected server options through the module-level dict.
        # Every key is read from the ServerArgs attribute of the same name.
        shared_option_names = (
            "attention_backend",
            "sampling_backend",
            "triton_attention_reduce_in_fp32",
            "disable_mla",
            "torchao_config",
            "enable_nan_detection",
            "enable_dp_attention",
            "enable_ep_moe",
            "device",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
            "enable_flashinfer_mla",
            "disable_radix_cache",
            "flashinfer_mla_disable_ragged",
            "debug_tensor_dump_output_folder",
            "debug_tensor_dump_inject",
        )
        global_server_args_dict.update(
            {option: getattr(args, option) for option in shared_option_names}
        )

        # CPU offload budget, converted from GB to bytes.
        set_cpu_offload_max_bytes(int(args.cpu_offload_gb * 1024**3))

        # -- 4. Distributed init (measures memory before model loading),
        # then model / memory-pool setup.
        min_per_gpu_memory = self.init_torch_distributed()
        # If it is a draft model tp_group can be different.
        self.initialize(min_per_gpu_memory)

    def initialize(self, min_per_gpu_memory: float):
        """Load the model and set up the state needed for forward passes.

        Order: memory-saver adapter, sampler, weights, torchao quantization,
        torch-native TP, LoRA, memory pools, then attention backend (plus
        cuBLAS and CUDA graphs on CUDA).

        Args:
            min_per_gpu_memory: Value returned by ``init_torch_distributed``;
                forwarded to ``init_memory_pool`` as ``total_gpu_memory``.
        """
        args = self.server_args
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=args.enable_memory_saver
        )

        self.sampler = Sampler()
        self.load_model()

        # Layered loading may already have applied torchao; don't apply twice.
        if not getattr(self.model, "torchao_applied", False):
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Torch-native tensor parallelism only for models that opt in.
        model_supports_tp = getattr(self.model, "supports_torch_tp", False)
        use_torch_tp = self.tp_size > 1 and bool(model_supports_tp)
        if use_torch_tp:
            self.apply_torch_tp()
        self.torch_tp_applied = use_torch_tp

        if args.lora_paths is not None:
            self.init_lora_manager()

        self.init_memory_pool(
            min_per_gpu_memory,
            args.max_running_requests,
            args.max_total_tokens,
        )

        # cuBLAS runs before the attention backend and CUDA graphs after it;
        # other devices get no CUDA graph runner.
        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()
211

212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
    def model_specific_adjustment(self):
        """Adjust ``server_args`` and ``mem_fraction_static`` for this model.

        - MLA models (unless ``disable_mla``) on non-CPU devices switch the
          attention backend to ``flashinfer_mla`` or ``triton``.
        - Double sparsity forces the triton backend, disables CUDA graphs and
          loads the heavy-channel config (which must be specified).
        - Multimodal models get ``mem_fraction_static`` scaled by 0.95, plus
          per-architecture overrides for Mllama and Qwen2-VL.

        Raises:
            ValueError: double sparsity is enabled without
                ``ds_heavy_channel_type`` (raised after the backend overrides
                have already been written).
        """
        args = self.server_args
        cfg = self.model_config

        # TODO: add MLA optimization on CPU
        use_mla = (
            cfg.attention_arch == AttentionArch.MLA
            and not args.disable_mla
            and args.device != "cpu"
        )
        if use_mla:
            if args.enable_flashinfer_mla:
                backend = "flashinfer_mla"
                note = "MLA optimization is turned on. Use flashinfer mla backend."
            else:
                backend = "triton"
                note = "MLA optimization is turned on. Use triton backend."
            logger.info(note)
            args.attention_backend = backend

        if args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            args.attention_backend = "triton"
            args.disable_cuda_graph = True
            channel_type = args.ds_heavy_channel_type
            if channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(channel_type)

        if not self.is_multimodal:
            return

        self.mem_fraction_static *= 0.95
        logger.info(
            f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
            f"because this is a multimodal model."
        )

        archs = cfg.hf_config.architectures
        if archs == ["MllamaForConditionalGeneration"]:
            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
            args.chunked_prefill_size = -1
        elif archs == ["Qwen2VLForConditionalGeneration"]:
            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
            logger.info(
                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
            )
            args.chunked_prefill_size = -1
            args.disable_radix_cache = True

265
    def init_torch_distributed(self):
        """Select the device, then set up torch distributed and parallel groups.

        The distributed environment, model-parallel groups and DP-attention
        groups are only initialized for the target worker; a draft worker
        reuses them and just fetches ``tp_group`` / ``attention_tp_group``.

        Returns:
            The available device memory reported by
            ``get_available_gpu_memory(..., distributed=self.tp_size > 1)``
            after initialization (logged elsewhere in this class as GB).

        Raises:
            ValueError: if ``self.device`` has no known communication backend
                (target worker only), or if available memory differs by more
                than 10% between this rank and the minimum across TP ranks.
        """
        logger.info("Init torch distributed begin.")

        torch.get_device_module(self.device).set_device(self.gpu_id)

        # Collective-communication backend for each supported device type.
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
        }.get(self.device)
        if backend is None and not self.is_draft_worker:
            # Fail fast with a clear message instead of an UnboundLocalError
            # at init_distributed_environment below.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device!r}. "
                "Expected one of: cuda, xpu, hpu, cpu."
            )

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism: every rank should see roughly
        # the same free memory, otherwise another process is likely using a GPU.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and min_per_gpu_memory < local_gpu_memory * 0.9:
            raise ValueError(
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
325

Lianmin Zheng's avatar
Lianmin Zheng committed
326
    def load_model(self):
        """Load the model weights and record load-time settings on ``self``.

        Sets ``self.load_config``, ``self.model``, ``self.sliding_window_size``
        and ``self.dtype``. On CUDA devices below sm80 it also switches
        ``server_args.dtype`` / ``model_config.dtype`` to float16. When
        ``kv_cache_dtype == "fp8_e4m3"`` it loads KV-cache scaling factors from
        ``quantization_param_path`` if one is given.

        Raises:
            RuntimeError: if the CUDA device is older than sm75, or FP8 KV-cache
                scaling factors were provided but the model cannot load them.
            ValueError: if the monitored barrier across TP ranks fails, e.g.
                another rank did not finish loading within
                ``UNBALANCED_MODEL_LOADING_TIMEOUT_S`` seconds.
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        # Reuse the measurement above instead of querying device memory twice.
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            # Compare full (major, minor) tuples; checking only the minor
            # version would wrongly accept e.g. sm35/sm37.
            capability = tuple(torch.cuda.get_device_capability())
            if capability < (8, 0):
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region():
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Undo the patches even if loading fails so they do not leak into
            # the rest of the process.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    # Exceptions do not apply %-formatting, so build the
                    # message explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

421
422
423
424
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Checkpoint location to read the new weights from.
            load_format: ``LoadConfig`` load format for the new checkpoint.
                Only formats served by ``DefaultModelLoader`` are supported.

        Returns:
            ``(success, message)``. On success, ``model_config.model_path``,
            ``server_args.model_path`` / ``server_args.load_format`` and
            ``self.load_config`` refer to the new checkpoint. On failure,
            ``model_config.model_path`` is restored. If loading the new weights
            failed part-way, the original weights are reloaded from the
            previous checkpoint.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember the current checkpoint so a failed update can restore it.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(weight_loader, path):
            # Iterator over (name, tensor) pairs of the checkpoint at `path`.
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    path,
                    revision=self.model_config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights):
            # Load the weights, then let quantized modules post-process them
            # on the target device.
            model.load_weights(weights)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(loader, model_path)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back from the *previous* checkpoint with the loader
                # configuration it was originally loaded with (this keeps
                # e.g. download_dir). Reloading from `model_path` would just
                # re-read the checkpoint that failed.
                self.model_config.model_path = original_model_path
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    # Only DefaultModelLoader exposes a weights iterator here.
                    rollback_loader = loader
                weights_iter = get_weight_iter(rollback_loader, original_model_path)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
486

487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Args:
            master_address: Host of the TCP rendezvous for the new group.
            master_port: Port of the TCP rendezvous for the new group.
            rank_offset: Added to ``self.tp_rank`` to get this process's rank
                in the new group.
            world_size: Total number of ranks in the new group.
            group_name: Non-empty name of the new group.
            backend: torch.distributed backend for the new group.

        Returns:
            ``(success, message)``. Errors while creating the group are logged
            and reported via the return value rather than raised.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            if backend == "nccl":
                # `device_ids` names the local device this process uses (set
                # via set_device(self.gpu_id) during init), not its rank in
                # the new group. Passing `rank` picks the wrong GPU whenever
                # rank_offset + tp_rank != gpu_id.
                dist.barrier(
                    group=self._model_update_group, device_ids=[self.gpu_id]
                )
            else:
                dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated, either a
                ``torch.dtype`` or the name of one (e.g. ``"bfloat16"``).
            shape: the shape of the parameter to be updated.

        Returns:
            ``(success, message)``. Broadcast/load errors are logged and
            reported via the return value; in that case the model weights may
            be partially updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        # `_model_update_group` only exists after init_weights_update_group
        # has run; getattr lets the assertion fire with its message instead of
        # an AttributeError.
        assert (
            getattr(self, "_model_update_group", None) is not None
        ), "model update group must be initialized"

        try:
            # Receive the tensor broadcast by rank 0 (the training engine).
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load ``(name, tensor)`` pairs into the live model.

        Every tensor is first passed through ``_unwrap_tensor`` with this
        runner's ``tp_rank``. ``load_format=None`` then uses the model's own
        ``load_weights``; ``"direct"`` uses ``_model_load_weights_direct``.

        Returns:
            ``(True, "Success")``.

        Raises:
            NotImplementedError: for any other ``load_format`` (raised after
                the tensors have been unwrapped).
        """
        unwrapped = [
            (param_name, _unwrap_tensor(value, tp_rank=self.tp_rank))
            for param_name, value in named_tensors
        ]

        if load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        elif load_format is not None:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        else:
            self.model.load_weights(unwrapped)

        return True, "Success"
582

583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Return the weights of parameter ``name`` (like HF ``get_parameter``).

        Delegates to ``self.model.get_weights_by_name(name, truncate_size,
        tp_size=self.tp_size)``. Any exception is logged and turned into a
        ``None`` result.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as err:
            logger.error(f"Error when getting parameter {name}: {err}")
            weights = None
        return weights

600
601
602
603
604
605
606
607
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for the adapters in ``server_args.lora_paths``.

        Uses the already-loaded base model, its HF config, load config and
        dtype, and logs once the manager exists.
        """
        args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            lora_paths=args.lora_paths,
            max_loras_per_batch=args.max_loras_per_batch,
            lora_backend=args.lora_backend,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

612
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens fit in the KV cache.

        Args:
            total_gpu_memory: Memory reference in the unit returned by
                ``get_available_gpu_memory`` (multiplied by ``1 << 30`` below,
                i.e. treated as GiB). A ``1 - mem_fraction_static`` share of it
                is kept out of the KV cache.

        Returns:
            Number of token slots (``int``): the remaining memory in bytes
            divided by the per-token KV-cache footprint. It is negative if
            available memory is below the reserved share.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        cfg = self.model_config
        elem_bytes = torch._utils._element_size(self.kv_cache_dtype)
        uses_mla = (
            cfg.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if uses_mla:
            # MLA caches kv_lora_rank + qk_rope_head_dim elements per layer,
            # independent of the number of heads.
            per_token_bytes = (
                (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
                * cfg.num_hidden_layers
                * elem_bytes
            )
        else:
            # Per attention-TP rank: KV heads x head_dim x layers x 2
            # (presumably K and V) x element size.
            per_token_bytes = (
                cfg.get_num_kv_heads(get_attention_tp_size())
                * cfg.head_dim
                * cfg.num_hidden_layers
                * 2
                * elem_bytes
            )

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // per_token_bytes)

639
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Resolve the KV-cache dtype, size the KV cache and allocate the pools.

        Args:
            total_gpu_memory: Total GPU memory; forwarded to
                ``profile_max_num_token`` to compute how many tokens fit.
            max_num_reqs: Number of request slots. When ``None`` it is derived
                from the profiled token budget and clamped to [2048, 4096].
            max_total_tokens: Optional cap on the KV-cache token count. It can
                only lower the profiled value, never raise it.

        Raises:
            ValueError: If ``server_args.kv_cache_dtype`` is not supported.
            RuntimeError: If no tokens fit after rounding down to a whole
                number of pages.
        """
        # --- KV-cache element dtype ---
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            # Previously only CUDA assigned a dtype here, so every other
            # platform silently left kv_cache_dtype unset. Mirror the fp8_e5m2
            # branch: HIP uses its natively supported fnuz variant.
            if is_hip():
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        # --- Token budget ---
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            # Scale with the token budget relative to the context length,
            # clamped to the range [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (as everywhere else in this file)
                # rather than the root logger.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        # --- Request -> token index pool ---
        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # --- Token -> KV storage pool, chosen by attention architecture ---
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )

        # --- Allocator over the KV pool (paged when page_size > 1) ---
        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            # Only a draft worker arrives here with an allocator already set.
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
    def init_cublas(self):
        """Warm up cuBLAS by running one tiny fp16 matmul on "cuda".

        A small GEMM is needed up front to initialize cuBLAS; without it some
        errors are raised later. The 16x16 product is returned.
        """
        shape = (16, 16)
        lhs = torch.ones(shape, dtype=torch.float16, device="cuda")
        rhs = torch.ones(shape, dtype=torch.float16, device="cuda")
        return lhs @ rhs

    def init_attention_backend(self):
        """Init attention kernel backend.

        Picks the backend class named by ``server_args.attention_backend``,
        importing it lazily, and stores the instance in ``self.attn_backend``.

        Raises:
            ValueError: If the configured backend name is not recognized.
        """
        backend_name = self.server_args.attention_backend

        if backend_name == "flashinfer":
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            # Init streams: EAGLE gets a dedicated CUDA stream, created before
            # the backend is constructed.
            if self.server_args.speculative_algorithm == "EAGLE":
                self.plan_stream_for_flashinfer = torch.cuda.Stream()
            self.attn_backend = FlashInferAttnBackend(self)

        elif backend_name == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                backend_cls = DoubleSparseAttnBackend
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                backend_cls = TritonAttnBackend
            self.attn_backend = backend_cls(self)

        elif backend_name == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)

        elif backend_name == "flashinfer_mla":
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            self.attn_backend = FlashInferMLAAttnBackend(self)

        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer heavy-channel tables for double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path``. For each
        hidden layer ``i`` it takes the entry keyed
        ``model.layers.<i>.self_attn.<selected_channel>_proj``, keeps its first
        ``server_args.ds_heavy_channel_num`` columns, and appends the result to
        ``self.sorted_channels`` as a contiguous CUDA tensor.
        """
        proj_suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []

        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        heavy_num = self.server_args.ds_heavy_channel_num
        for layer_idx in range(self.model_config.num_hidden_layers):
            layer_key = f"model.layers.{layer_idx}.self_attn{proj_suffix}"
            channels = torch.tensor(channel_config[layer_key])[:, :heavy_num]
            self.sorted_channels.append(channels.contiguous().cuda())

    def init_cuda_graphs(self):
        """Capture cuda graphs.

        Always resets ``self.cuda_graph_runner`` to ``None`` first. Capture is
        skipped for non-generation models and when ``disable_cuda_graph`` is
        set. Otherwise a ``CudaGraphRunner`` is built, and the elapsed time and
        GPU-memory change are logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        start_time = time.time()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )

        self.cuda_graph_runner = CudaGraphRunner(self)

        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph end. Time elapsed: {time.time() - start_time:.2f} s. "
            f"avail mem={mem_after:.2f} GB. mem usage={(mem_before - mem_after):.2f} GB."
        )
    def apply_torch_tp(self):
        """Shard ``self.model`` with torch tensor parallelism.

        Builds a 1-D device mesh of ``self.tp_size`` devices on ``self.device``
        and passes the model and mesh to sglang's ``tensor_parallel`` helper.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
        """Run a decode step.

        Prepares the attention backend's metadata for ``forward_batch``, then
        calls the model.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        ids, positions = forward_batch.input_ids, forward_batch.positions
        return self.model.forward(ids, positions, forward_batch)

    def forward_extend(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ):
        """Run an extend step.

        Attention metadata is prepared unless ``skip_attn_backend_init`` is set.
        Embedding (non-generation) models are called with
        ``get_embedding=True``. Generation models receive ``input_embeds``, cast
        to bfloat16, when the batch carries them.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        positional_args = (
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(*positional_args, get_embedding=True)

        if forward_batch.input_embeds is not None:
            return self.model.forward(
                *positional_args,
                input_embeds=forward_batch.input_embeds.bfloat16(),
            )

        return self.model.forward(*positional_args)
    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the model on an idle batch; no attention metadata is prepared."""
        model_forward = self.model.forward
        return model_forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

    def forward(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ) -> LogitsProcessorOutput:
        """Dispatch ``forward_batch`` to the matching forward path.

        A captured CUDA graph is replayed when the batch's mode supports it and
        the graph runner exists and accepts the batch. Otherwise decode, extend
        and idle batches go to their dedicated methods.

        Raises:
            ValueError: If the forward mode is none of decode/extend/idle.
        """
        mode = forward_batch.forward_mode

        can_replay = (
            mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_replay:
            return self.cuda_graph_runner.replay(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )

        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )
        if mode.is_idle():
            return self.forward_idle(forward_batch)

        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Prepare ``logits_output.next_token_logits`` for sampling.

        First makes sure the regex/grammar vocab mask is up to date, then
        applies the logit bias through ``sampling_info.apply_logits_bias``.
        """
        if not sampling_info.sampling_info_done:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        elif sampling_info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            sampling_info.sampling_info_done.wait()

        # Apply logit bias
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample the next tokens for ``forward_batch``.

        The logits are preprocessed (vocab mask, logit bias) first. Then
        ``self.sampler`` is called with the batch's sampling info and logprob
        settings.

        Args:
            logits_output: The logits output from the model forward. A tuple
                (duplex models with several output streams) is sampled element
                by element, and the results are stacked along the last dim.
            forward_batch: The forward batch that generates logits_output

        Returns:
            A tensor of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            per_stream = [
                self.sample(stream_logits, forward_batch)
                for stream_logits in logits_output
            ]
            return torch.stack(per_stream, dim=-1)

        sampling_info = forward_batch.sampling_info
        self._preprocess_logits(logits_output, sampling_info)

        # Sample the next tokens
        return self.sampler(
            logits_output,
            sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Whether the HF config's ``rope_scaling`` declares type "mrope".

        A missing attribute or a ``None`` value counts as "not mrope". mrope
        requires keeping "rope_deltas" between the prompt and decoding phases.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        return rope_scaling is not None and rope_scaling.get("type", None) == "mrope"
    def save_remote_model(self, url: str):
        """Save the loaded model to ``url`` via ``RemoteModelLoader.save_model``.

        The model's configured ``model_path`` is passed along with the target URL.
        """
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        source_path = self.model_config.model_path
        RemoteModelLoader.save_model(self.model, source_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Write the model as shards to ``path`` via ``ShardedStateLoader.save_model``.

        ``pattern`` and ``max_size`` are forwarded unchanged to the loader.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        message = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(message)
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank):
    """Return ``tensor`` as-is, or this rank's deserialized entry of a LocalSerializedTensor."""
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor
    return tensor.get(tp_rank)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    # One serialized payload per rank; values[i] belongs to rank i.
    values: List[bytes]

    def get(self, rank: int):
        # Deserialize and return the payload stored for ``rank``.
        return MultiprocessingSerializer.deserialize(self.values[rank])