model_runner.py 43.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import datetime
17
import gc
Shuo Yang's avatar
Shuo Yang committed
18
import json
19
import logging
20
import os
21
import time
22
23
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
24
25

import torch
26
import torch.distributed as dist
27
28
29
30
31

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
32
33
34
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
35
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
36
)
37
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
38
39
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
40
    get_attention_tp_size,
41
42
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
43
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
45
from sglang.srt.layers.sampler import Sampler
46
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
47
from sglang.srt.lora.lora_manager import LoRAManager
48
from sglang.srt.managers.schedule_batch import global_server_args_dict
49
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
50
    DoubleSparseTokenToKVPool,
51
52
53
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
54
    TokenToKVPoolAllocator,
55
)
Lianmin Zheng's avatar
Lianmin Zheng committed
56
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
Yineng Zhang's avatar
Yineng Zhang committed
57
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
58
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
59
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
60
61
62
63
64
65
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,
    device_loading_context,
    get_model_loader,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
66
from sglang.srt.model_loader.weight_utils import default_weight_loader
67
from sglang.srt.patch_torch import monkey_patch_torch_reductions
68
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
69
from sglang.srt.server_args import ServerArgs
70
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
71
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
72
from sglang.srt.utils import (
73
    MultiprocessingSerializer,
74
    enable_show_time_cost,
75
    get_available_gpu_memory,
76
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
77
    is_cuda,
HAI's avatar
HAI committed
78
    is_hip,
79
    monkey_patch_p2p_access_check,
80
    monkey_patch_vllm_gguf_config,
81
    set_cpu_offload_max_bytes,
82
    set_cuda_arch,
83
)
84

Ying Sheng's avatar
Ying Sheng committed
85
logger = logging.getLogger(__name__)

# Optional override (presumably for CI runs) that, when set, replaces the
# profiled `max_total_num_tokens` with int(value) in `init_memory_pool`.
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Timeout (seconds) for the post-load `monitored_barrier` in `load_model`,
# i.e. how long a rank waits for the other TP ranks to finish loading.
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 5 * 60


Lianmin Zheng's avatar
Lianmin Zheng committed
91
class ModelRunner:
92
93
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
94
95
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        """Record the runner configuration, then set up distribution and the model.

        Args:
            model_config: Configuration of the model to run.
            mem_fraction_static: Static memory fraction; may be lowered by
                ``model_specific_adjustment`` (e.g. for multimodal models).
            gpu_id: Local device index this runner binds to.
            tp_rank: Tensor-parallel rank of this runner.
            tp_size: Tensor-parallel world size.
            nccl_port: Port for the default ``tcp://127.0.0.1`` init method,
                used when ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments; several fields are rewritten in
                place by ``model_specific_adjustment``.
            is_draft_worker: When True, the distributed environment is not
                initialized by this runner.
            req_to_token_pool: Optional pre-built pool, stored as-is.
            token_to_kv_pool_allocator: Optional pre-built allocator, stored as-is.
        """
        args = server_args

        # Plain configuration taken from the arguments.
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = args.device
        self.gpu_id = gpu_id
        self.tp_rank, self.tp_size = tp_rank, tp_size
        self.dist_port = nccl_port
        self.server_args = args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        # True only on TP rank 0.
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            args.speculative_algorithm
        )
        self.page_size = args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # May rewrite server_args fields (e.g. attention_backend) and
        # mem_fraction_static, so it has to run before those are read below.
        self.model_specific_adjustment()

        if args.show_time_cost:
            enable_show_time_cost()

        if args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        # Publish selected (possibly adjusted) server args process-wide.
        shared_args = {
            "attention_backend": args.attention_backend,
            "sampling_backend": args.sampling_backend,
            "triton_attention_reduce_in_fp32": args.triton_attention_reduce_in_fp32,
            "disable_mla": args.disable_mla,
            "torchao_config": args.torchao_config,
            "enable_nan_detection": args.enable_nan_detection,
            "enable_dp_attention": args.enable_dp_attention,
            "enable_ep_moe": args.enable_ep_moe,
            "enable_deepep_moe": args.enable_deepep_moe,
            "device": args.device,
            "speculative_accept_threshold_single": args.speculative_accept_threshold_single,
            "speculative_accept_threshold_acc": args.speculative_accept_threshold_acc,
            "enable_flashinfer_mla": args.enable_flashinfer_mla,
            "enable_flashmla": args.enable_flashmla,
            "disable_radix_cache": args.disable_radix_cache,
            "flashinfer_mla_disable_ragged": args.flashinfer_mla_disable_ragged,
            "debug_tensor_dump_output_folder": args.debug_tensor_dump_output_folder,
            "debug_tensor_dump_inject": args.debug_tensor_dump_inject,
        }
        global_server_args_dict.update(shared_args)

        # CPU offload budget: GiB -> bytes.
        set_cpu_offload_max_bytes(int(args.cpu_offload_gb * 1024**3))

        # Available memory is measured during distributed init, before the
        # model is loaded. For a draft model the tp_group can be different.
        self.initialize(self.init_torch_distributed())

    def initialize(self, min_per_gpu_memory: float):
        """Load the model and build everything needed to run forward passes.

        Args:
            min_per_gpu_memory: Available memory as returned by
                ``init_torch_distributed``; used to size the memory pool.
        """
        args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        # Sampler plus model weights.
        self.sampler = Sampler()
        self.load_model()

        # torchao quantization; layered loading may have applied it already.
        if not getattr(self.model, "torchao_applied", False):
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Torch-native tensor parallelism, only for models that opt in.
        if self.tp_size > 1 and getattr(self.model, "supports_torch_tp", False):
            self.apply_torch_tp()

        if args.lora_paths is not None:
            self.init_lora_manager()

        # Memory pool first, then attention backends (and CUDA graphs on CUDA).
        self.init_memory_pool(
            min_per_gpu_memory,
            args.max_running_requests,
            args.max_total_tokens,
        )

        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()

        # Auxiliary hidden-state capture for EAGLE-3 on the target (non-draft)
        # model. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            self.model.set_eagle3_layers_to_capture()

216
217
218
219
220
221
222
223
224
225
226
227
228
229
    def model_specific_adjustment(self):
        """Adjust server arguments that depend on the model or enabled features.

        Mutates ``self.server_args`` in place (attention backend, CUDA graph,
        chunked prefill, radix cache) and may lower ``self.mem_fraction_static``.

        Raises:
            ValueError: if double sparsity is enabled without
                ``ds_heavy_channel_type``.
        """
        args = self.server_args

        # Choose an MLA-capable attention backend. TODO: add MLA optimization on CPU
        mla_enabled = (
            self.model_config.attention_arch == AttentionArch.MLA
            and not args.disable_mla
        )
        if mla_enabled and args.device != "cpu":
            if args.enable_flashinfer_mla:
                logger.info(
                    "MLA optimization is turned on. Use flashinfer mla backend."
                )
                args.attention_backend = "flashinfer_mla"
            elif args.enable_flashmla:
                logger.info("MLA optimization is turned on. Use flashmla decode.")
                args.attention_backend = "flashmla"
            elif args.attention_backend == "fa3":
                # fa3 is kept as the backend; only log the choice.
                logger.info(
                    "MLA optimization is turned on. Use flash attention 3 backend."
                )
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                args.attention_backend = "triton"

        if args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            args.attention_backend = "triton"
            args.disable_cuda_graph = True
            if args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(args.ds_heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                "because this is a multimodal model."
            )

            archs = self.model_config.hf_config.architectures

            if archs == ["MllamaForConditionalGeneration"]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                args.chunked_prefill_size = -1

            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
            if archs in (
                ["Qwen2VLForConditionalGeneration"],
                ["Qwen2_5_VLForConditionalGeneration"],
            ):
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
                )
                args.chunked_prefill_size = -1
                args.disable_radix_cache = True

            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
            if archs == ["DeepseekVL2ForCausalLM"]:
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
                )
                args.chunked_prefill_size = -1
                args.disable_radix_cache = True

        if args.enable_deepep_moe:
            logger.info("DeepEP is turned on.")

289
    def init_torch_distributed(self):
        """Bind this process to its device and initialize torch.distributed.

        Returns:
            float: available device memory (GB, as reported by
            ``get_available_gpu_memory`` with ``distributed=tp_size > 1``)
            measured after initialization; the caller uses it to size the
            memory pool.

        Raises:
            ValueError: if ``self.device`` has no known communication backend
                (checked only when this worker initializes the distributed
                environment), or if available memory on this rank is much
                larger than the value reported across TP ranks.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Communication backend per device type; None for unknown devices.
        backend = {
            "cuda": "nccl",
            "xpu": "xccl",
            "hpu": "hccl",
            "cpu": "gloo",
        }.get(self.device)

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            if backend is None:
                # Previously this surfaced as an UnboundLocalError on `backend`.
                raise ValueError(
                    f"Unsupported device for torch distributed: {self.device!r}. "
                    "Expected one of 'cuda', 'xpu', 'hpu', 'cpu'."
                )
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
356

Lianmin Zheng's avatar
Lianmin Zheng committed
357
    def load_model(self):
        """Load model weights onto ``self.device`` and record derived settings.

        Sets ``self.load_config``, ``self.model``, ``self.sliding_window_size``
        and ``self.dtype``. On CUDA devices below sm80 the dtype is downgraded
        to float16. Ends with a monitored barrier over the TP CPU group so a
        rank that finished loading reports stragglers instead of waiting
        indefinitely.

        Raises:
            RuntimeError: on CUDA devices older than sm75, or when FP8 KV-cache
                scaling factors are given for a model that cannot load them.
            ValueError: if the post-load barrier fails (e.g. other TP ranks do
                not arrive within ``UNBALANCED_MODEL_LOADING_TIMEOUT_S`` seconds).
        """
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(f"Load weight begin. avail mem={before_avail_memory:.2f} GB")

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            capability = torch.cuda.get_device_capability()
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare (major, minor) together; checking only the minor
                # version would accept e.g. a 6.x device with minor >= 5.
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()
        try:
            with self.memory_saver_adapter.region():
                self.model = get_model(
                    model_config=self.model_config,
                    load_config=self.load_config,
                    device_config=DeviceConfig(self.device),
                )
        finally:
            # Undo the patches even if loading fails.
            monkey_patch_vllm_parallel_state(reverse=True)
            monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    # Exceptions do not apply %-formatting, so build the
                    # message explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

452
453
454
455
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Checkpoint location to load the new weights from.
            load_format: Load format for the new checkpoint (only formats
                handled by ``DefaultModelLoader`` are supported).

        Returns:
            ``(success, message)``. On success the new ``model_path`` and
            ``load_format`` are recorded in ``server_args`` and
            ``self.load_config`` is replaced. On any failure
            ``self.model_config.model_path`` is restored; if loading the new
            weights failed midway, the original checkpoint is reloaded.
        """
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember the current checkpoint so a failed update can roll back to it.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(weight_loader, path):
            # Build a weights iterator for `path` using `weight_loader`.
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    path,
                    revision=self.model_config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            # Load weights, then run per-module post-load quantization hooks.
            model.load_weights(weights_iter)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(loader, model_path)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back to the ORIGINAL checkpoint (model_config.model_path
                # already points at the new one), preferring the load config
                # the current weights were loaded with.
                self.model_config.model_path = original_model_path
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    rollback_loader = loader
                weights_iter = get_weight_iter(rollback_loader, original_model_path)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
517

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Args:
            master_address: Host of the group's TCP rendezvous.
            master_port: Port of the group's TCP rendezvous.
            rank_offset: Added to ``self.tp_rank`` to form this process's rank
                in the new group.
            world_size: Total number of ranks in the new group.
            group_name: Non-empty name of the group.
            backend: torch.distributed backend for the group.

        Returns:
            ``(success, message)``; failures are logged, not raised.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            # `device_ids` expects a local device index. `rank` is this
            # process's rank in the custom group (offset by `rank_offset`) and
            # need not be a valid local device, so use the device this runner
            # was bound to (`self.gpu_id`) instead.
            dist.barrier(group=self._model_update_group, device_ids=[self.gpu_id])
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated, either a
                ``torch.dtype`` or the name of one (e.g. ``"bfloat16"``).
            shape: the shape of the parameter to be updated.

        Returns:
            ``(success, message)``; broadcast/load failures are logged and
            reported rather than raised.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        # getattr so a missing init_weights_update_group() call trips this
        # assertion instead of raising a bare AttributeError.
        model_update_group = getattr(self, "_model_update_group", None)
        assert model_update_group is not None, "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            # Receive the tensor from rank 0 (the training engine).
            torch.distributed.broadcast(weights, src=0, group=model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        """Load weights from in-memory ``(name, tensor)`` pairs.

        Every tensor is first passed through ``_unwrap_tensor`` with this
        runner's ``tp_rank``.

        Args:
            named_tensors: ``(name, tensor)`` pairs to load.
            load_format: ``None`` uses ``self.model.load_weights``;
                ``"direct"`` uses ``_model_load_weights_direct``.

        Returns:
            ``(True, "Success")``.

        Raises:
            NotImplementedError: for any other ``load_format``.
        """
        unwrapped = []
        for tensor_name, raw_tensor in named_tensors:
            unwrapped.append(
                (tensor_name, _unwrap_tensor(raw_tensor, tp_rank=self.tp_rank))
            )

        if load_format is None:
            self.model.load_weights(unwrapped)
        elif load_format == "direct":
            _model_load_weights_direct(self.model, unwrapped)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
613

614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Return the weights of parameter ``name`` (cf. Hugging Face ``get_parameter``).

        Delegates to ``self.model.get_weights_by_name`` with ``truncate_size``
        and this runner's ``tp_size``; any exception is logged and ``None`` is
        returned instead.

        Intended only for unit tests; it is not optimized. For fast access use
        torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            weights = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            weights = None
        return weights

631
632
633
634
635
636
637
638
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for the adapters in ``server_args.lora_paths``."""
        args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            lora_paths=args.lora_paths,
            max_loras_per_batch=args.max_loras_per_batch,
            lora_backend=args.lora_backend,
            load_config=self.load_config,
            dtype=self.dtype,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
        )
        logger.info("LoRA manager ready.")

645
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit in the remaining memory.

        Budget = currently available memory minus the non-static share
        ``total_gpu_memory * (1 - mem_fraction_static)``. The budget is
        converted to bytes with ``1 << 30`` and divided by the per-token
        cache size across all layers.

        Args:
            total_gpu_memory: Memory figure in the same units as
                ``get_available_gpu_memory`` (multiplied by ``1 << 30`` below).

        Returns:
            int: maximum number of cacheable tokens.
        """
        available_gb = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        cfg = self.model_config
        elem_bytes = torch._utils._element_size(self.kv_cache_dtype)
        uses_mla = (
            cfg.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if uses_mla:
            # MLA: (kv_lora_rank + qk_rope_head_dim) elements per layer.
            bytes_per_token = (
                (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
                * cfg.num_hidden_layers
                * elem_bytes
            )
        else:
            # Per local KV head and layer; the factor 2 presumably covers K and V.
            bytes_per_token = (
                cfg.get_num_kv_heads(get_attention_tp_size())
                * cfg.head_dim
                * cfg.num_hidden_layers
                * 2
                * elem_bytes
            )

        budget_gb = available_gb - total_gpu_memory * (1 - self.mem_fraction_static)
        return int(budget_gb * (1 << 30) // bytes_per_token)

672
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and create the request pool, the KV-cache pool and its allocator.

        Steps:
          1. Resolve ``self.kv_cache_dtype`` from ``server_args.kv_cache_dtype``.
          2. Profile ``self.max_total_num_tokens``, then adjust it for the CI
             override, speculative decoding, the ``max_total_tokens`` cap and
             page alignment.
          3. Create ``req_to_token_pool`` (unless one is already set, which is
             only allowed for a draft worker), ``token_to_kv_pool`` (MLA,
             double-sparsity or MHA layout) and ``token_to_kv_pool_allocator``
             (unless already set, again only for a draft worker).

        Args:
            total_gpu_memory: Total device memory, forwarded to
                ``profile_max_num_token``.
            max_num_reqs: Maximum number of concurrent requests. If ``None``,
                derived from the token budget and clamped to [2048, 4096].
            max_total_tokens: Optional cap on the KV-cache size. It can only
                lower the profiled value.

        Raises:
            ValueError: if ``server_args.kv_cache_dtype`` is not supported.
            RuntimeError: if the resulting token budget is not positive.
        """
        kv_cache_dtype = self.server_args.kv_cache_dtype
        if kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif kv_cache_dtype == "fp8_e4m3":
            # Every platform must get a dtype here: profile_max_num_token
            # below reads self.kv_cache_dtype. Mirror the e5m2 branch and use
            # the natively supported fnuz format on HIP.
            if is_hip():
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # Target worker and draft worker shares the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger, consistent with the rest of this file.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        # Round the token budget down to a whole number of pages.
        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if self.req_to_token_pool is None:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs + 1,
                max_context_len=self.model_config.context_len + 4,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )

        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
        else:
            assert self.is_draft_worker

        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
824

Lianmin Zheng's avatar
Lianmin Zheng committed
825
826
827
828
829
830
831
832
833
    def init_cublas(self):
        """Run one tiny fp16 matmul on CUDA so that cuBLAS gets initialized.

        Skipping this warm-up leads to errors being raised later. Returns the
        16x16 product.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        return lhs @ rhs

834
835
    def init_attention_backend(self):
        """Create ``self.attn_backend`` for the backend named in ``server_args``.

        Each backend module is imported inside its own branch, so only the
        selected implementation has to be importable.

        Raises:
            AssertionError: if triton is chosen for a sliding-window or
                encoder-decoder model, or fa3 is chosen on a device whose
                compute-capability major version is below 9.
            ValueError: if the backend name is not recognized.
        """
        backend = self.server_args.attention_backend

        if backend == "flashinfer":
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            # Init streams
            if self.server_args.speculative_algorithm == "EAGLE":
                self.plan_stream_for_flashinfer = torch.cuda.Stream()
            self.attn_backend = FlashInferAttnBackend(self)
            return

        if backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if not self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
            else:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
            return

        if backend == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
            return

        if backend == "flashinfer_mla":
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            self.attn_backend = FlashInferMLAAttnBackend(self)
            return

        if backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
            return

        if backend == "fa3":
            sm_major = torch.cuda.get_device_capability()[0]
            assert sm_major >= 9, (
                "FlashAttention v3 Backend requires SM>=90. "
                "Please use `--attention-backend flashinfer`."
            )
            logger.warning(
                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
            return

        raise ValueError(
            f"Invalid attention backend: {self.server_args.attention_backend}"
        )
897

Shuo Yang's avatar
Shuo Yang committed
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer heavy-channel table used by double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path``. For each
        hidden layer ``i``, it takes the entry keyed
        ``model.layers.{i}.self_attn.{selected_channel}_proj``, keeps the first
        ``ds_heavy_channel_num`` columns, and appends the contiguous CUDA tensor
        to ``self.sorted_channels`` (one entry per layer, in layer order).
        """
        proj_suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []
        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        for layer_idx in range(self.model_config.num_hidden_layers):
            key = f"model.layers.{layer_idx}.self_attn{proj_suffix}"
            heavy = torch.tensor(channel_config[key])[
                :, : self.server_args.ds_heavy_channel_num
            ]
            self.sorted_channels.append(heavy.contiguous().cuda())

915
    def init_cuda_graphs(self):
        """Capture CUDA graphs by building a ``CudaGraphRunner``.

        ``self.cuda_graph_runner`` stays ``None`` for non-generation models or
        when ``server_args.disable_cuda_graph`` is set. Otherwise the capture
        is timed, and the free memory before and after is logged.
        """
        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        started_at = time.time()
        mem_before = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture cuda graph begin. This can take up to several minutes. avail mem={mem_before:.2f} GB"
        )
        self.cuda_graph_runner = CudaGraphRunner(self)
        mem_after = get_available_gpu_memory(self.device, self.gpu_id)
        elapsed = time.time() - started_at
        logger.info(
            f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s. "
            f"avail mem={mem_after:.2f} GB. mem usage={(mem_before - mem_after):.2f} GB."
        )
937

938
939
940
941
942
943
944
    def apply_torch_tp(self):
        """Shard ``self.model`` with torch-native tensor parallelism.

        A 1-D device mesh of ``tp_size`` devices is built on ``self.device``
        and handed to ``tensor_parallel``.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

945
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run a decode step: prepare attention metadata, then call the model."""
        self.attn_backend.init_forward_metadata(forward_batch)
        input_ids, positions = forward_batch.input_ids, forward_batch.positions
        return self.model.forward(input_ids, positions, forward_batch)

951
952
953
954
955
956
    def forward_extend(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ):
        """Run an extend pass over ``forward_batch``.

        Args:
            forward_batch: The batch to run.
            skip_attn_backend_init: If True, attention metadata is not
                (re)initialized here; the caller is expected to have done it.

        Returns:
            The result of ``self.model.forward``. Generation models receive
            ``input_embeds`` when the batch carries them. Non-generation
            models are called with ``get_embedding=True``.
        """
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

        if forward_batch.input_embeds is None:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )

        # Cast caller-supplied embeddings to the runner's dtype (the same
        # dtype used for LoRA and the "auto" KV cache) rather than a
        # hard-coded bfloat16, which mismatches fp16/fp32 models.
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            input_embeds=forward_batch.input_embeds.to(self.dtype),
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
977

Ke Bao's avatar
Ke Bao committed
978
979
980
981
982
    def forward_idle(self, forward_batch: ForwardBatch):
        """Call the model on an idle batch; attention metadata is not prepared here."""
        batch = forward_batch
        return self.model.forward(batch.input_ids, batch.positions, batch)

983
984
985
    def forward(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ) -> LogitsProcessorOutput:
        """Route a batch to CUDA-graph replay or to the eager forward path.

        A captured graph is replayed when the forward mode allows it and the
        graph runner accepts the batch. Otherwise the batch is sent to
        ``forward_decode``, ``forward_extend`` or ``forward_idle`` based on its
        mode.

        Raises:
            ValueError: if the forward mode is not decode, extend or idle.
        """
        mode = forward_batch.forward_mode

        use_graph = (
            mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if use_graph:
            return self.cuda_graph_runner.replay(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )

        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )
        if mode.is_idle():
            return self.forward_idle(forward_batch)
        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
1005

1006
1007
1008
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        """Prepare the grammar vocab mask, then apply logit biases.

        If ``sampling_info.sampling_info_done`` is set (overlap mode), the mask
        was already computed while processing the previous batch's result, so
        this only waits on it when grammars are present. Otherwise the mask is
        computed here. In both cases ``apply_logits_bias`` is then called on
        ``logits_output.next_token_logits``.
        """
        done_signal = sampling_info.sampling_info_done
        if not done_signal:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        elif sampling_info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            done_signal.wait()

        # Apply logit bias
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample next tokens (and logprobs) from ``logits_output``.

        Logits are first preprocessed (grammar vocab mask and logit bias). The
        sampler is then called with the batch's logprob settings; per the
        original contract it also computes logprobs and updates
        ``logits_output``.

        Args:
            logits_output: The logits output from the model forward. Duplex
                models with several output streams pass a tuple of outputs;
                each element is sampled independently.
            forward_batch: The forward batch that generated ``logits_output``.

        Returns:
            A tensor of next token ids. For a tuple input, the per-stream
            results are stacked along a new last dimension.
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                dim=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
        return next_token_ids

Yineng Zhang's avatar
Yineng Zhang committed
1053
1054
1055
1056
1057
1058
1059
1060
    @property
    def model_is_mrope(self) -> bool:
        """True when the HF config's ``rope_scaling`` has type ``"mrope"``.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases. A missing or ``None`` ``rope_scaling`` yields False.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        return rope_scaling is not None and rope_scaling.get("type", None) == "mrope"
1061

1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
    def save_remote_model(self, url: str):
        """Save the loaded model to the remote location ``url`` via ``RemoteModelLoader``."""
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        source_path = self.model_config.model_path
        RemoteModelLoader.save_model(self.model, source_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Write the model as shards under ``path`` using ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        summary = (
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        logger.info(summary)
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)

1078
1079
1080
1081
1082
1083
1084
1085
1086

def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Load each ``(name, tensor)`` pair into ``model``'s parameter of that name.

    Every pair goes through ``default_weight_loader``. A name that is not
    among ``model.named_parameters()`` raises ``KeyError``.
    """
    params_by_name = dict(model.named_parameters())
    for param_name, weight in named_tensors:
        target_param = params_by_name[param_name]
        default_weight_loader(target_param, weight)


def _unwrap_tensor(tensor, tp_rank):
    """Return ``tensor`` moved to the current CUDA device.

    A ``LocalSerializedTensor`` is first deserialized to the entry for
    ``tp_rank``, after applying ``monkey_patch_torch_reductions``. Any other
    value is moved directly with ``.to``.
    """
    if not isinstance(tensor, LocalSerializedTensor):
        return tensor.to(torch.cuda.current_device())
    monkey_patch_torch_reductions()
    rank_tensor = tensor.get(tp_rank)
    return rank_tensor.to(torch.cuda.current_device())
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100


@dataclass
class LocalSerializedTensor:
    """A tensor serialized once per rank with ``MultiprocessingSerializer``.

    The serializer only encodes a pointer, not the tensor data.
    ``values[i]`` is the serialized entry for the GPU of rank ``i``.
    """

    # Serialized payloads, indexed by rank.
    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the entry stored for ``rank``."""
        payload = self.values[rank]
        return MultiprocessingSerializer.deserialize(payload)