model_runner.py 31.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Shuo Yang's avatar
Shuo Yang committed
17
import json
18
import logging
19
import time
20
from typing import List, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22

import torch
23
import torch.distributed as dist
zhyncs's avatar
zhyncs committed
24
25
26
27
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
28
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
29
)
Lianmin Zheng's avatar
Lianmin Zheng committed
30

31
32
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
33
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
Shuo Yang's avatar
Shuo Yang committed
34
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
35
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
36
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
37
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
38
39
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
40
    get_attention_tp_size,
41
42
    initialize_dp_attention,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
43
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
44
from sglang.srt.layers.sampler import Sampler
45
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
46
from sglang.srt.lora.lora_manager import LoRAManager
47
from sglang.srt.managers.schedule_batch import global_server_args_dict
48
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
49
    DoubleSparseTokenToKVPool,
50
51
52
53
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
54
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
55
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
56
from sglang.srt.server_args import ServerArgs
57
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
58
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
59
from sglang.srt.utils import (
60
    enable_show_time_cost,
61
    get_available_gpu_memory,
62
    init_custom_process_group,
bjmsong's avatar
bjmsong committed
63
    is_cuda,
HAI's avatar
HAI committed
64
    is_hip,
65
    monkey_patch_vllm_gguf_config,
66
    monkey_patch_vllm_p2p_access_check,
67
    set_cpu_offload_max_bytes,
68
)
69

Ying Sheng's avatar
Ying Sheng committed
70
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
71

Lianmin Zheng's avatar
Lianmin Zheng committed
72
73

class ModelRunner:
74
75
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
76
77
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
    ):
        """Adjust settings, load the model, and allocate pools and backends.

        Args:
            model_config: Model description. Its attention architecture,
                multimodal flag and HF architectures drive the adjustments
                made below.
            mem_fraction_static: Value of --mem-fraction-static. It is scaled
                by 0.95 for multimodal models.
            gpu_id: Device index handed to the device module and to the
                memory queries.
            tp_rank: Tensor-parallel rank of this worker; rank 0 sets
                `should_log`.
            tp_size: Tensor-parallel world size.
            nccl_port: Port used for the local TCP init method when
                `server_args.dist_init_addr` is not set.
            server_args: Server options. NOTE: some fields are overwritten in
                place here (`attention_backend`, `disable_cuda_graph`,
                `chunked_prefill_size`, `disable_radix_cache`).
            is_draft_worker: When True, `init_torch_distributed` skips the
                distributed environment initialization.

        Raises:
            ValueError: If double sparsity is enabled but
                `server_args.ds_heavy_channel_type` is None.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        # Created lazily by `init_weights_update_group`. Defined here so that
        # `update_weights_from_distributed` can check it against None instead
        # of failing with AttributeError.
        self._model_update_group = None

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # TODO: add MLA optimization on CPU
            if self.server_args.device != "cpu":
                logger.info("MLA optimization is turned on. Use triton backend.")
                self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )

            if self.model_config.hf_config.architectures == [
                "MllamaForConditionalGeneration"
            ]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                server_args.chunked_prefill_size = -1

            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                )
                server_args.chunked_prefill_size = -1
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            # Imported lazily: only needed when the cache must be disabled.
            from outlines.caching import disable_cache

            disable_cache()

        # Publish the (possibly adjusted) settings through the shared dict.
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
            }
        )

        # cpu_offload_gb is converted to bytes (1024**3 per unit).
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Apply torchao quantization
        apply_torchao_config_to_model(
            self.model, global_server_args_dict["torchao_config"]
        )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            # CUDA graphs are only set up on the "cuda" device.
            self.cuda_graph_runner = None
            self.init_attention_backend()
214
215

    def init_torch_distributed(self):
        """Initialize the distributed runtime and report available memory.

        Selects the communication backend from `self.device`, initializes the
        distributed environment, model parallelism and DP attention (skipped
        for draft workers), and stores the TP groups on `self`.

        Returns:
            The value of `get_available_gpu_memory` queried with
            `distributed=self.tp_size > 1`.

        Raises:
            ValueError: If `self.device` has no known backend, or if the
                returned memory is below 90% of this device's own available
                memory when `tp_size > 1`.
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            # TODO(liangan1): Just use gloo to bypass the initialization failure.
            # Need to use xccl for xpu backend in the future
            backend = "gloo"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        else:
            # Without this branch `backend` would be unbound and fail later
            # with an UnboundLocalError.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device}"
            )

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
270

Lianmin Zheng's avatar
Lianmin Zheng committed
271
    def load_model(self):
        """Load the model and derive attributes that depend on it.

        Sets `self.load_config`, `self.model`, `self.sliding_window_size` and
        `self.dtype`. On CUDA devices with compute capability below sm80,
        `server_args.dtype` and `model_config.dtype` are forced to float16.

        Raises:
            RuntimeError: If the CUDA compute capability is below sm75, or if
                FP8 (e4m3) KV-cache scaling factors are provided but the model
                has no callable `load_kv_cache_scales`.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            capability = torch.cuda.get_device_capability()
            # Compare (major, minor) as a whole: checking the minor version
            # alone would let e.g. sm35/sm37 through.
            if capability < (7, 5):
                raise RuntimeError("SGLang only supports sm75 and above.")
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    # Exceptions do not interpolate %-style arguments the way
                    # logging calls do, so the message is formatted here.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support loading "
                        "scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
342

343
344
345
346
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Location of the new weights; stored on
                `model_config.model_path` while loading.
            load_format: Load format used to build the `LoadConfig`.

        Returns:
            `(True, message)` on success. `(False, message)` when the loader
            is not a `DefaultModelLoader`, when the weights iterator cannot be
            created, or when loading fails. In the last case the weights are
            reloaded from the previous `model_path`.
        """
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        # Remember the current path so that a failed update can really roll
        # back; the weight iterator reads `model_config.model_path`.
        original_model_path = self.model_config.model_path
        self.model_config.model_path = model_path

        def get_weight_iter(config):
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            model.load_weights(weights_iter)
            # Let every quantization method post-process its module.
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Point the config back at the previous weights before
                # reloading; otherwise the "rollback" would read the new path.
                self.model_config.model_path = original_model_path
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
415

416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        This worker joins the group as rank `rank_offset + self.tp_rank`.

        Returns:
            `(True, message)` once the group is created and the barrier has
            passed, otherwise `(False, message)` with the error text.
        """
        assert (
            dist.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            self._model_update_group = update_group
            dist.barrier(group=update_group, device_ids=[rank])
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

        return True, "Succeeded to initialize custom process group."

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

495
496
497
    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
        self.model.load_weights(named_tensors)
        return True, "Success"
498

499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

516
517
518
519
520
521
522
523
524
525
526
    def init_lora_manager(self):
        """Create `self.lora_manager` from the loaded model and server options."""
        args = self.server_args
        manager = LoRAManager(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        self.lora_manager = manager
        logger.info("LoRA manager ready.")

527
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many KV-cache tokens fit in the remaining memory.

        The per-token cell size is computed from the model dimensions and the
        element size of `self.kv_cache_dtype`. The memory budget is the
        currently available memory minus the
        `total_gpu_memory * (1 - mem_fraction_static)` share kept aside.

        Args:
            total_gpu_memory: Memory figure in the same unit as
                `get_available_gpu_memory` (scaled by `1 << 30` below).

        Returns:
            The token count as an int; it is not clamped here.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        config = self.model_config
        uses_mla = (
            config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if uses_mla:
            # MLA branch: kv_lora_rank + qk_rope_head_dim per layer.
            per_layer_elems = config.kv_lora_rank + config.qk_rope_head_dim
        else:
            # Non-MLA branch: heads * head_dim, doubled.
            per_layer_elems = (
                config.get_num_kv_heads(get_attention_tp_size()) * config.head_dim * 2
            )
        cell_size = (
            per_layer_elems
            * config.num_hidden_layers
            * torch._utils._element_size(self.kv_cache_dtype)
        )

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // cell_size)

554
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Resolve the KV-cache dtype and allocate the request and KV pools.

        Sets `self.kv_cache_dtype`, `self.max_total_num_tokens`,
        `self.req_to_token_pool` and `self.token_to_kv_pool`.

        Args:
            total_gpu_memory: Memory figure forwarded to
                `profile_max_num_token`.
            max_num_reqs: Maximum number of requests; when None it is derived
                from the profiled token count and clamped to [2048, 4096].
            max_total_tokens: Optional upper bound on the profiled token count.

        Raises:
            ValueError: If `kv_cache_dtype` is unknown, or is "fp8_e4m3" while
                `is_cuda()` is False.
            RuntimeError: If the resulting token budget is not positive.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if not is_cuda():
                # Fail early: otherwise `kv_cache_dtype` stays unset and the
                # profiling below dies with an AttributeError.
                raise ValueError(
                    "kv_cache_dtype fp8_e4m3 requires CUDA, but is_cuda() is False."
                )
            self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                # The draft worker reuses the size published by the target worker.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
            else:
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    + max_num_reqs * self.server_args.speculative_num_steps
                    + 100
                )

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger, like every other message in this file.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
            enable_memory_saver=self.server_args.enable_memory_saver,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
657

Lianmin Zheng's avatar
Lianmin Zheng committed
658
659
660
661
662
663
664
665
666
    def init_cublas(self):
        """Warm up cuBLAS by running one tiny fp16 matmul on the CUDA device.

        A small matmul has to be executed once to init cublas; otherwise
        errors are raised later on.

        Returns:
            The 16x16 product of two all-ones float16 matrices.
        """
        operand_spec = {"dtype": torch.float16, "device": "cuda"}
        lhs = torch.ones((16, 16), **operand_spec)
        rhs = torch.ones((16, 16), **operand_spec)
        return torch.matmul(lhs, rhs)

667
668
669
670
671
672
673
674
    def init_attention_backend(self):
        """Init attention kernel backend.

        Picks the backend named by ``server_args.attention_backend`` and stores
        the constructed object in ``self.attn_backend``:

        - ``"flashinfer"``: ``FlashInferAttnBackend``
        - ``"triton"``: ``DoubleSparseAttnBackend`` when
          ``server_args.enable_double_sparsity`` is set, otherwise
          ``TritonAttnBackend``
        - ``"torch_native"``: ``TorchNativeAttnBackend``

        Raises:
            AssertionError: If the triton backend is requested while a sliding
                window is configured or the model is an encoder-decoder.
            ValueError: If the backend name is none of the above.
        """
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            # The triton path rejects window attention and cross attention.
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
690

Shuo Yang's avatar
Shuo Yang committed
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer sorted channel indices used by double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path`` and, for
        each of the ``model_config.num_hidden_layers`` layers, looks up the
        entry keyed ``model.layers.<i>.self_attn.<selected_channel>_proj``.
        Each entry is converted to a tensor, cut down along its second
        dimension to the first ``server_args.ds_heavy_channel_num`` columns,
        made contiguous, moved to CUDA and appended to ``self.sorted_channels``
        (which is reset to an empty list first).
        """
        proj_suffix = "".join((".", selected_channel, "_proj"))
        self.sorted_channels = []

        # Parse the whole channel config up front; the file is closed afterwards.
        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        for layer_idx in range(self.model_config.num_hidden_layers):
            layer_key = f"model.layers.{layer_idx}.self_attn{proj_suffix}"
            layer_channels = torch.tensor(channel_config[layer_key])
            heavy_channels = layer_channels[:, : self.server_args.ds_heavy_channel_num]
            self.sorted_channels.append(heavy_channels.contiguous().cuda())

708
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        ``self.cuda_graph_runner`` is always reset to ``None`` first. It is
        replaced by a ``CudaGraphRunner`` only when the model is a generation
        model and ``server_args.disable_cuda_graph`` is not set. The capture
        duration is logged.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        # perf_counter() is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        tic = time.perf_counter()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
        elapsed = time.perf_counter() - tic
        logger.info(f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s")
725

726
727
728
729
730
731
732
    def apply_torch_tp(self):
        """Apply torch tensor parallelism to ``self.model``.

        Builds a one-dimensional device mesh of size ``self.tp_size`` on
        ``self.device`` and passes it, together with the model, to
        ``sglang.srt.model_parallel.tensor_parallel``.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh = dist.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, mesh)

733
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run the model forward pass for a decode batch.

        The attention backend's forward metadata is initialised for this batch
        before the model is invoked.

        Args:
            forward_batch: Batch whose ``input_ids`` and ``positions`` are fed
                to the model; the batch itself is passed as the third argument.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

739
    def forward_extend(self, forward_batch: ForwardBatch):
        """Run the model forward pass for an extend batch.

        The attention backend's forward metadata is initialised first. What is
        passed to the model then depends on the model kind:

        - generation model without ``forward_batch.input_embeds``: plain call
          with ``input_ids``, ``positions`` and the batch;
        - generation model with ``input_embeds``: same call plus
          ``input_embeds`` converted with ``.bfloat16()``;
        - non-generation model: same call plus ``get_embedding=True``.

        Returns:
            Whatever ``self.model.forward`` returns.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            if forward_batch.input_embeds is None:
                return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
                )
            else:
                return self.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    input_embeds=forward_batch.input_embeds.bfloat16(),
                )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
761

Ke Bao's avatar
Ke Bao committed
762
763
764
765
766
    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the model forward pass for an idle batch.

        Unlike the decode/extend paths, no attention-backend metadata is
        initialised here; the model is called directly with the batch's
        ``input_ids``, ``positions`` and the batch itself.
        """
        run_model = self.model.forward
        return run_model(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

767
    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch a batch to the matching forward path.

        If the batch's forward mode reports ``is_cuda_graph()``, a cuda graph
        runner exists and it ``can_run`` the batch, the captured graph is
        replayed. Otherwise the batch goes to ``forward_decode``,
        ``forward_extend`` or ``forward_idle`` according to its forward mode.

        Raises:
            ValueError: If the forward mode is none of decode, extend or idle.
        """
        if (
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        ):
            return self.cuda_graph_runner.replay(forward_batch)

        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
783

784
785
786
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Sample the next token ids for a batch.

        Finishes the sampling-info preparation, applies the logit bias to
        ``logits_output.next_token_logits`` and then calls ``self.sampler``.

        Args:
            logits_output: Output of the forward pass; its
                ``next_token_logits`` are passed to ``apply_logits_bias``.
            forward_batch: Batch providing ``sampling_info``,
                ``return_logprob`` and ``top_logprobs_nums``.

        Returns:
            The value returned by ``self.sampler`` (the next token ids).
        """
        # Apply logit bias
        sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
            sampling_info.update_penalties()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
        )
        return next_token_ids

Yineng Zhang's avatar
Yineng Zhang committed
809
810
811
812
813
814
815
816
    @property
    def model_is_mrope(self) -> bool:
        """Whether the model's ``rope_scaling`` config has type ``"mrope"``.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases. A missing or ``None`` ``rope_scaling`` on the HF config counts
        as not mrope.
        """
        hf_config = self.model_config.hf_config
        rope_scaling = getattr(hf_config, "rope_scaling", {})
        return rope_scaling is not None and rope_scaling.get("type", None) == "mrope"