model_runner.py 32.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Shuo Yang's avatar
Shuo Yang committed
17
import json
18
import logging
19
import time
20
from typing import List, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22

import torch
23
import torch.distributed as dist
24
25
26
27
28

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
zhyncs's avatar
zhyncs committed
29
30
31
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
32
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
33
)
34
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
Shuo Yang's avatar
Shuo Yang committed
35
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
36
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
37
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
38
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
39
40
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
41
    get_attention_tp_size,
42
43
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
44
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
45
from sglang.srt.layers.sampler import Sampler
46
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
47
from sglang.srt.lora.lora_manager import LoRAManager
48
from sglang.srt.managers.schedule_batch import global_server_args_dict
49
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
50
    DoubleSparseTokenToKVPool,
51
52
53
54
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
55
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
56
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
57
from sglang.srt.server_args import ServerArgs
58
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
59
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
60
from sglang.srt.utils import (
61
    enable_show_time_cost,
62
    get_available_gpu_memory,
63
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
64
    is_cuda,
HAI's avatar
HAI committed
65
    is_hip,
66
    monkey_patch_vllm_gguf_config,
67
    monkey_patch_vllm_p2p_access_check,
68
    set_cpu_offload_max_bytes,
69
)
70

Ying Sheng's avatar
Ying Sheng committed
71
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
72

Lianmin Zheng's avatar
Lianmin Zheng committed
73
74

class ModelRunner:
75
76
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
77
78
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
    ):
        """Adjust settings, init distributed state, load the model, allocate pools.

        Args:
            model_config: Configuration of the model to load.
            mem_fraction_static: Memory fraction later used by
                `profile_max_num_token`; multiplied by 0.95 for multimodal
                models.
            gpu_id: Index of the local device this runner binds to.
            tp_rank: Tensor-parallel rank of this process.
            tp_size: Tensor-parallel world size.
            nccl_port: Port of the local TCP rendezvous, used when
                `server_args.dist_init_addr` is not set.
            server_args: Server arguments. Several fields are overwritten in
                place by the model-specific adjustments below.
            is_draft_worker: When True, the distributed environment is not
                initialized again by this runner.

        Raises:
            ValueError: If double sparsity is enabled without
                `ds_heavy_channel_type`.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Model-specific adjustment
        mla_enabled = (
            model_config.attention_arch == AttentionArch.MLA
            and not server_args.disable_mla
        )
        # TODO: add MLA optimization on CPU
        if mla_enabled and server_args.device != "cpu":
            logger.info("MLA optimization is turned on. Use triton backend.")
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True
            heavy_channel_type = server_args.ds_heavy_channel_type
            if heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(heavy_channel_type)

        if self.is_multimodal:
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )

            architectures = model_config.hf_config.architectures
            if architectures == ["MllamaForConditionalGeneration"]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                server_args.chunked_prefill_size = -1

            if architectures == ["Qwen2VLForConditionalGeneration"]:
                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                )
                server_args.chunked_prefill_size = -1
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            # Imported lazily so outlines is only touched when this flag is set.
            from outlines.caching import disable_cache

            disable_cache()

        shared_settings = {
            "attention_backend": server_args.attention_backend,
            "sampling_backend": server_args.sampling_backend,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "disable_mla": server_args.disable_mla,
            "torchao_config": server_args.torchao_config,
            "enable_nan_detection": server_args.enable_nan_detection,
            "enable_dp_attention": server_args.enable_dp_attention,
            "enable_ep_moe": server_args.enable_ep_moe,
            "device": server_args.device,
        }
        global_server_args_dict.update(shared_settings)

        # cpu_offload_gb is given in GiB; the setter takes bytes.
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Apply torchao quantization
        apply_torchao_config_to_model(
            self.model, global_server_args_dict["torchao_config"]
        )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        use_torch_tp = bool(self.tp_size > 1 and supports_torch_tp)
        if use_torch_tp:
            self.apply_torch_tp()
        self.torch_tp_applied = use_torch_tp

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )

        # cuBLAS warm-up and CUDA graph capture only happen on CUDA; every
        # device gets an attention backend.
        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()
215
216

    def init_torch_distributed(self):
        """Initialize the distributed environment and report available memory.

        Binds the process to `self.gpu_id` and, unless this runner is a draft
        worker, initializes the distributed environment, model parallelism
        and DP attention. The TP group and attention TP group are stored on
        `self`.

        Returns:
            The value of `get_available_gpu_memory(..., distributed=self.tp_size > 1)`
            (presumably the minimum across TP ranks -- TODO confirm against
            the helper).

        Raises:
            ValueError: If `self.device` has no known communication backend
                (only checked when the environment is actually initialized),
                or if this rank's free memory shows the devices are
                unbalanced.
        """
        logger.info("Init torch distributed begin.")

        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)

        # TODO(liangan1): xpu uses gloo only to bypass the initialization
        # failure; switch to xccl for the xpu backend in the future.
        backend_by_device = {
            "cuda": "nccl",
            "xpu": "gloo",
            "hpu": "hccl",
            "cpu": "gloo",
        }
        backend = backend_by_device.get(self.device)

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            if backend is None:
                # Previously an unknown device left `backend` unbound and
                # failed here with an UnboundLocalError.
                raise ValueError(
                    f"Unsupported device for torch distributed: {self.device}"
                )
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            # Tolerate up to 10% less than this rank's own free memory.
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
271

Lianmin Zheng's avatar
Lianmin Zheng committed
272
    def load_model(self):
        """Load the model and set the attributes that depend on it.

        Sets `self.load_config`, `self.model`, `self.sliding_window_size` and
        `self.dtype`. On CUDA devices with compute capability major version
        below 8, `server_args.dtype` and `model_config.dtype` are forced to
        float16 before loading.

        Raises:
            RuntimeError: If the CUDA device is older than sm75, or if FP8
                KV cache scaling factors are provided but the model has no
                callable `load_kv_cache_scales`.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)

        if self.device == "cuda":
            capability = tuple(torch.cuda.get_device_capability())
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare (major, minor) as a pair: checking only the minor
                # version let devices such as sm37 slip through.
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            scales_path = self.server_args.quantization_param_path
            if scales_path is None:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )
            elif callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(scales_path)
                logger.info("Loaded KV cache scaling factors from %s", scales_path)
            else:
                # Exceptions do not apply %-formatting to their arguments, so
                # the message is built explicitly.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    f"model {self.model.__class__} does not support loading scaling factors."
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
346

347
348
349
350
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Path of the checkpoint to load the new weights from.
            load_format: Load format used to build the `LoadConfig`.

        Returns:
            `(True, message)` on success. `(False, message)` if the loader is
            not a `DefaultModelLoader`, the weights iterator cannot be
            created, or loading fails. In the last case the weights are
            reloaded from the previously configured model path.
        """
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Keep the old path so a failed update neither leaves the config
        # pointing at the new checkpoint nor "rolls back" by re-reading it.
        previous_model_path = self.model_config.model_path
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = previous_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            model.load_weights(weights_iter)
            # Run the quantization post-processing hook on each module that
            # has a quant method.
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                self.model_config.model_path = previous_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back from the checkpoint that was loaded before this call.
                self.model_config.model_path = previous_model_path
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
419

420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
448
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            dist.barrier(group=self._model_update_group, device_ids=[rank])
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

499
500
501
    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
        self.model.load_weights(named_tensors)
        return True, "Success"
502

503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

520
521
522
523
524
525
526
527
528
529
530
    def init_lora_manager(self):
        """Create `self.lora_manager` for the loaded base model.

        Reads `lora_paths` and `max_loras_per_batch` from the server
        arguments and reuses the runner's load config and dtype, so it must
        run after `load_model`, which sets those attributes.
        """
        args = self.server_args
        manager_kwargs = dict(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        self.lora_manager = LoRAManager(**manager_kwargs)
        logger.info("LoRA manager ready.")

531
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit in the remaining memory.

        Args:
            total_gpu_memory: Memory figure measured before model loading.
                It is scaled by `1 << 30` below, so it is presumably in GB --
                TODO confirm against `get_available_gpu_memory`.

        Returns:
            The number of token slots that fit, as an int. It can be zero or
            negative when memory is insufficient.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        config = self.model_config
        element_bytes = torch._utils._element_size(self.kv_cache_dtype)
        uses_mla = (
            config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )

        # cell_size is the number of KV cache bytes one token needs across
        # all layers.
        if uses_mla:
            per_layer_width = config.kv_lora_rank + config.qk_rope_head_dim
            cell_size = per_layer_width * config.num_hidden_layers * element_bytes
        else:
            kv_heads = config.get_num_kv_heads(get_attention_tp_size())
            # The factor 2 is presumably one key and one value entry per head.
            cell_size = (
                kv_heads
                * config.head_dim
                * config.num_hidden_layers
                * 2
                * element_bytes
            )

        # Memory left for the cache after reserving the non-static fraction.
        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // cell_size)

558
    def init_memory_pool(
559
560
        self,
        total_gpu_memory: int,
561
562
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
563
    ):
564
565
566
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
HAI's avatar
HAI committed
567
568
569
570
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
bjmsong's avatar
bjmsong committed
571
572
573
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if is_cuda():
                self.kv_cache_dtype = torch.float8_e4m3fn
574
575
576
577
578
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

579
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
            else:
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    + max_num_reqs * self.server_args.speculative_num_steps
                    + 100
                )

602
603
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
604
                logging.warning(
605
606
607
608
609
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
610

611
        if self.max_total_num_tokens <= 0:
612
            raise RuntimeError(
613
                "Not enough memory. Please try to increase --mem-fraction-static."
614
            )
615

Liangsheng Yin's avatar
Liangsheng Yin committed
616
        self.req_to_token_pool = ReqToTokenPool(
617
618
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
Zhang, Liangang's avatar
Zhang, Liangang committed
619
            device=self.device,
620
            enable_memory_saver=self.server_args.enable_memory_saver,
Lianmin Zheng's avatar
Lianmin Zheng committed
621
        )
622
623
        if (
            self.model_config.attention_arch == AttentionArch.MLA
Ke Bao's avatar
Ke Bao committed
624
            and not self.server_args.disable_mla
625
626
627
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
628
                dtype=self.kv_cache_dtype,
629
630
631
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
Zhang, Liangang's avatar
Zhang, Liangang committed
632
                device=self.device,
633
                enable_memory_saver=self.server_args.enable_memory_saver,
634
            )
Shuo Yang's avatar
Shuo Yang committed
635
636
637
638
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
639
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
Shuo Yang's avatar
Shuo Yang committed
640
641
642
643
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
644
                enable_memory_saver=self.server_args.enable_memory_saver,
Shuo Yang's avatar
Shuo Yang committed
645
            )
646
647
648
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
649
                dtype=self.kv_cache_dtype,
650
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
651
652
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
Zhang, Liangang's avatar
Zhang, Liangang committed
653
                device=self.device,
654
                enable_memory_saver=self.server_args.enable_memory_saver,
655
            )
656
        logger.info(
657
            f"Memory pool end. "
Zhang, Liangang's avatar
Zhang, Liangang committed
658
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
659
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
660

Lianmin Zheng's avatar
Lianmin Zheng committed
661
662
663
664
665
666
667
668
669
    def init_cublas(self):
        """Initialize cuBLAS by running one small matmul.

        Without this warm-up, errors are raised later on. Returns the product of
        two 16x16 float16 all-ones tensors created on the "cuda" device.
        """
        shape = (16, 16)
        lhs = torch.ones(shape, dtype=torch.float16, device="cuda")
        rhs = torch.ones(shape, dtype=torch.float16, device="cuda")
        return lhs @ rhs

670
671
672
673
674
675
676
677
    def init_attention_backend(self):
        """Create ``self.attn_backend`` for the configured attention backend.

        The backend name is read from ``self.server_args.attention_backend``;
        "flashinfer", "triton" and "torch_native" are recognized.

        Raises:
            AssertionError: if "triton" is requested while a sliding window is
                set or the model is an encoder-decoder.
            ValueError: if the backend name is not one of the recognized names.
        """
        backend_name = self.server_args.attention_backend

        if backend_name == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
            return

        if backend_name == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
            return

        if backend_name != "triton":
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

        # From here on the triton backend was requested; check its restrictions
        # before constructing anything.
        assert self.sliding_window_size is None, (
            "Window attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )
        assert not self.model_config.is_encoder_decoder, (
            "Cross attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )

        # Double sparsity swaps in its own backend class on the triton path.
        if self.server_args.enable_double_sparsity:
            backend_cls = DoubleSparseAttnBackend
        else:
            backend_cls = TritonAttnBackend
        self.attn_backend = backend_cls(self)
693

Shuo Yang's avatar
Shuo Yang committed
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer channel indices from the double-sparsity config file.

        Reads the JSON file at ``self.server_args.ds_channel_config_path`` and,
        for every hidden layer, looks up the entry
        ``model.layers.<i>.self_attn.<selected_channel>_proj``. The first
        ``ds_heavy_channel_num`` columns of each entry are stored, as a
        contiguous CUDA tensor, in ``self.sorted_channels`` (one per layer).

        Args:
            selected_channel: Projection name used to build the lookup key
                (the ``_proj`` suffix is appended here).
        """
        proj_suffix = "." + selected_channel + "_proj"
        self.sorted_channels = []

        # Load the channel config.
        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        for layer_idx in range(self.model_config.num_hidden_layers):
            layer_key = f"model.layers.{layer_idx}.self_attn{proj_suffix}"
            layer_channels = torch.tensor(channel_config[layer_key])
            # Keep only the leading "heavy" columns.
            heavy_channels = layer_channels[:, : self.server_args.ds_heavy_channel_num]
            self.sorted_channels.append(heavy_channels.contiguous().cuda())

711
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        ``self.cuda_graph_runner`` is always reset to ``None`` first. It is then
        replaced by a ``CudaGraphRunner`` only when the model is a generation
        model and cuda graphs are not disabled in the server args.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        capture_enabled = (
            self.is_generation and not self.server_args.disable_cuda_graph
        )
        if not capture_enabled:
            return

        start = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
        elapsed = time.time() - start
        logger.info(f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s")
728

729
730
731
732
733
734
735
    def apply_torch_tp(self):
        """Apply ``tensor_parallel`` to ``self.model`` over a 1-D device mesh.

        The mesh is created on ``self.device`` with ``self.tp_size`` entries.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        tp_mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, tp_mesh)

736
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run the model on a decode batch.

        The attention backend's forward metadata is initialized for the batch
        before the model is invoked. Returns whatever ``self.model.forward``
        returns.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        token_ids = forward_batch.input_ids
        token_positions = forward_batch.positions
        return self.model.forward(token_ids, token_positions, forward_batch)

742
    def forward_extend(self, forward_batch: ForwardBatch):
        """Run the model on an extend batch.

        The attention backend's forward metadata is initialized first. For
        non-generation models the model is called with ``get_embedding=True``.
        For generation models, ``input_embeds`` is forwarded (cast to bfloat16)
        when the batch carries it. Returns whatever ``self.model.forward``
        returns.
        """
        self.attn_backend.init_forward_metadata(forward_batch)

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

        embeds = forward_batch.input_embeds
        if embeds is None:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )

        # NOTE(review): the cast below is unconditionally to bfloat16 — confirm
        # this is intended for models running in another dtype.
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            input_embeds=embeds.bfloat16(),
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
764

Ke Bao's avatar
Ke Bao committed
765
766
767
768
769
    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the model on an idle batch.

        Unlike the decode and extend paths, no attention metadata is
        initialized here. Returns whatever ``self.model.forward`` returns.
        """
        token_ids = forward_batch.input_ids
        token_positions = forward_batch.positions
        return self.model.forward(token_ids, token_positions, forward_batch)

770
    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch a batch to cuda graph replay or to its forward-mode handler.

        Cuda graph replay is used when the forward mode reports
        ``is_cuda_graph()``, a graph runner is set, and that runner says it can
        run the batch. Otherwise the batch goes to ``forward_decode``,
        ``forward_extend`` or ``forward_idle`` according to its forward mode.

        Raises:
            ValueError: if the forward mode is none of decode, extend or idle.
        """
        mode = forward_batch.forward_mode

        # The runner attribute is only consulted for cuda-graph modes.
        if mode.is_cuda_graph():
            graph_runner = self.cuda_graph_runner
            if graph_runner and graph_runner.can_run(forward_batch):
                return graph_runner.replay(forward_batch)

        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(forward_batch)
        if mode.is_idle():
            return self.forward_idle(forward_batch)

        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
786

787
788
789
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Sample the next token ids for a batch.

        Prepares the batch's sampling info, applies the logit bias to
        ``logits_output.next_token_logits``, then calls ``self.sampler`` and
        returns its result.
        """
        info = forward_batch.sampling_info

        if not info.sampling_info_done:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            info.update_regex_vocab_mask()
            info.update_penalties()
        elif info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch, so only wait for it.
            info.sampling_info_done.wait()

        # Apply logit bias
        info.apply_logits_bias(logits_output.next_token_logits)

        # Sample the next tokens
        return self.sampler(
            logits_output,
            info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
        )

Yineng Zhang's avatar
Yineng Zhang committed
812
813
814
815
816
817
818
819
    @property
    def model_is_mrope(self) -> bool:
        """Report whether the HF config's ``rope_scaling`` has type "mrope".

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases. A missing or ``None`` ``rope_scaling`` counts as not mrope.
        """
        hf_config = self.model_config.hf_config
        rope_scaling = getattr(hf_config, "rope_scaling", {})
        # The attribute can exist but be explicitly None.
        return rope_scaling is not None and rope_scaling.get("type", None) == "mrope"