model_runner.py 28.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Shuo Yang's avatar
Shuo Yang committed
17
import json
18
import logging
19
import time
20
from typing import Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22

import torch
23
import torch.distributed as dist
zhyncs's avatar
zhyncs committed
24
25
26
27
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
28
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
29
)
Lianmin Zheng's avatar
Lianmin Zheng committed
30

31
32
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
33
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
Shuo Yang's avatar
Shuo Yang committed
34
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
35
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
36
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
37
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Liangsheng Yin's avatar
Liangsheng Yin committed
38
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
39
from sglang.srt.layers.sampler import Sampler
40
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
41
from sglang.srt.lora.lora_manager import LoRAManager
42
from sglang.srt.managers.schedule_batch import global_server_args_dict
43
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
44
    DoubleSparseTokenToKVPool,
45
46
47
48
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
49
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
50
from sglang.srt.model_loader import get_model
51
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
52
from sglang.srt.server_args import ServerArgs
53
from sglang.srt.utils import (
54
    enable_show_time_cost,
55
    get_available_gpu_memory,
56
    init_custom_process_group,
HAI's avatar
HAI committed
57
    is_hip,
58
    monkey_patch_vllm_gguf_config,
59
    monkey_patch_vllm_p2p_access_check,
60
    set_cpu_offload_max_bytes,
61
)
62

Ying Sheng's avatar
Ying Sheng committed
63
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
64

Lianmin Zheng's avatar
Lianmin Zheng committed
65
66

class ModelRunner:
67
68
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
69
70
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Set up distributed state, load the model and allocate runtime resources.

        Args:
            model_config: Configuration of the model to run.
            mem_fraction_static: Fraction of device memory reserved for static
                allocations; reduced by 5% for multimodal models.
            gpu_id: Local device index bound by this runner.
            tp_rank: Tensor-parallel rank of this runner.
            tp_size: Tensor-parallel world size.
            nccl_port: Port for the default ``tcp://127.0.0.1`` init method,
                used when ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments. NOTE: mutated in place — the
                attention backend, CUDA-graph, chunked-prefill, radix-cache and
                dtype settings may be overridden here or in ``load_model``.

        Raises:
            ValueError: If double sparsity is enabled without a heavy channel
                type.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        # Process group for online weight updates, created on demand by
        # `init_weights_update_group`. Defined up front so that
        # `update_weights_from_distributed` reports its "must be initialized"
        # assertion instead of an AttributeError when called too early.
        self._model_update_group = None

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            # Multimodal models: turn off chunked prefill and shrink the static
            # memory fraction by 5%.
            server_args.chunked_prefill_size = -1
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"and turn off chunked prefill "
                f"because this is a multimodal model."
            )
            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        apply_torchao_config_to_model(
            self.model, global_server_args_dict["torchao_config"]
        )

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()
184
185

    def init_torch_distributed(self):
        """Initialize torch.distributed and the tensor-parallel group.

        Binds this process to ``self.gpu_id``, initializes the default process
        group with a device-specific backend, sets up the model-parallel
        groups, and checks that free memory is balanced across TP ranks.

        Returns:
            The minimum available device memory (GB, as reported by
            ``get_available_gpu_memory``) across the TP ranks.

        Raises:
            ValueError: If ``self.device`` has no known distributed backend, or
                if free memory is unbalanced across TP ranks.
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)

        # Distributed backend per device type.
        # TODO(liangan1): gloo is used for xpu only to bypass an initialization
        # failure; switch to xccl for the xpu backend in the future.
        device_backends = {"cuda": "nccl", "xpu": "gloo", "hpu": "hccl"}
        backend = device_backends.get(self.device)
        if backend is None:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `backend`.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device}. "
                f"Supported devices: {sorted(device_backends)}."
            )

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Check memory for tensor parallelism: the global minimum must be
        # within 10% of this rank's free memory.
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
227

Lianmin Zheng's avatar
Lianmin Zheng committed
228
    def load_model(self):
        """Load the model weights onto the device.

        On CUDA devices with compute capability below sm80, the dtype is forced
        to float16 because bfloat16 is unsupported. Devices below sm75 are
        rejected. Sets ``self.load_config``, ``self.model``,
        ``self.sliding_window_size`` and ``self.dtype``.

        Raises:
            RuntimeError: If the CUDA device's compute capability is below sm75.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                # Compare the full (major, minor) version. Checking only the
                # minor number wrongly accepted e.g. sm35/sm37.
                if (major, minor) < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
        )

        # None when the model does not expose a sliding attention window.
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
272

Chayenne's avatar
Chayenne committed
273
274
    def update_weights_from_disk(self, model_path: str, load_format: str):
        """Update engine weights online from disk.

        Loads the weights found at ``model_path`` into the existing model in
        place. If loading fails partway through, the original weights are
        reloaded from the previous model path, so the engine keeps serving the
        old model.

        Args:
            model_path: Location of the checkpoint providing the new weights.
            load_format: Load format used to select the model loader. Only
                formats served by ``DefaultModelLoader`` are supported.

        Returns:
            A ``(success, message)`` tuple.
        """
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        # Remember the current weight source. Failures restore it, and the
        # rollback reloads the *original* weights rather than the new ones.
        original_model_path = self.model_config.model_path
        original_load_config = self.load_config
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            self.model_config.model_path = original_model_path
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config, weight_loader):
            # Iterator over the (name, tensor) pairs stored under config.model_path.
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model_path,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            # Load the weights, then let quantized modules post-process them.
            model.load_weights(weights_iter)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config, loader)
            except Exception as e:
                self.model_config.model_path = original_model_path
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back from the ORIGINAL checkpoint. Use the loader that
                # matches its original load config, falling back to the new
                # loader when that one is not a DefaultModelLoader.
                self.model_config.model_path = original_model_path
                rollback_loader = get_model_loader(original_load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    rollback_loader = loader
                weights_iter = get_weight_iter(self.model_config, rollback_loader)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
343

344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            dist.barrier(group=self._model_update_group, device_ids=[rank])
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )
        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

441
442
443
444
445
446
447
448
449
450
451
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for the adapters in ``server_args.lora_paths``."""
        args = self.server_args
        manager_kwargs = dict(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        self.lora_manager = LoRAManager(**manager_kwargs)
        logger.info("LoRA manager ready.")

452
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens fit in the KV cache.

        The budget is the currently available device memory minus the share
        ``(1 - mem_fraction_static)`` of ``total_gpu_memory`` that must stay
        free. It is divided by the per-token KV-cache footprint (``cell_size``
        bytes).

        Args:
            total_gpu_memory: Device memory measured before model loading, in
                the same unit as ``get_available_gpu_memory`` (GB).

        Returns:
            The token capacity as an ``int``. It may be zero or negative when
            memory is insufficient; the caller checks this.

        NOTE: ``self.kv_cache_dtype`` must be set before calling.
        """
        free_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        cfg = self.model_config

        use_mla = (
            cfg.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if use_mla:
            # Per layer: kv_lora_rank + qk_rope_head_dim elements per token.
            elems_per_token = (
                cfg.kv_lora_rank + cfg.qk_rope_head_dim
            ) * cfg.num_hidden_layers
        else:
            # Per layer: local KV heads * head_dim, times 2 (presumably K and V).
            elems_per_token = (
                cfg.get_num_kv_heads(self.tp_size)
                * cfg.head_dim
                * cfg.num_hidden_layers
                * 2
            )
        cell_size = elems_per_token * torch._utils._element_size(self.kv_cache_dtype)

        reserved = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = free_memory - reserved
        return int(rest_memory * (1 << 30) // cell_size)

479
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Allocate the request-to-token pool and the token-to-KV-cache pool.

        Args:
            total_gpu_memory: Device memory measured before model loading (GB);
                forwarded to ``profile_max_num_token``.
            max_num_reqs: Size of the request pool (one extra slot is
                allocated). When ``None``, it is derived from the KV capacity
                and context length and clamped to ``[2048, 4096]``.
            max_total_tokens: Optional cap on KV-cache tokens. It can only
                lower the profiled capacity, never raise it.

        Raises:
            ValueError: If ``server_args.kv_cache_dtype`` is not supported.
            RuntimeError: If no KV-cache tokens fit in memory.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (not the root `logging` logger) like the
                # rest of this file.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
564

Lianmin Zheng's avatar
Lianmin Zheng committed
565
566
567
568
569
570
571
572
573
    def init_cublas(self):
        """Warm up cuBLAS with a tiny fp16 matmul on the current CUDA device.

        Initializing cuBLAS eagerly avoids errors that would otherwise be
        raised later. Returns the 16x16 product.
        """
        shape = (16, 16)
        lhs, rhs = (
            torch.ones(shape, dtype=torch.float16, device="cuda") for _ in range(2)
        )
        return lhs @ rhs

574
575
576
577
578
579
580
581
    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
582
            )
583
            assert not self.model_config.is_encoder_decoder, (
584
585
586
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
Shuo Yang's avatar
Shuo Yang committed
587
588
589
590
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
591
592
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
593
        else:
594
595
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
596
            )
597

Shuo Yang's avatar
Shuo Yang committed
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer heavy-channel tables used by double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path``. For every
        hidden layer ``i``, it keeps the first ``ds_heavy_channel_num`` columns
        of entry ``model.layers.{i}.self_attn.{selected_channel}_proj`` and
        appends the result to ``self.sorted_channels``.

        Args:
            selected_channel: Heavy channel type (``ds_heavy_channel_type``),
                e.g. ``"k"`` selects the ``.k_proj`` entries.

        NOTE: ``__init__`` calls this before ``init_torch_distributed`` binds the
        process to ``self.gpu_id``. The tensors are therefore moved to
        ``self.gpu_id`` explicitly instead of to the current (possibly default)
        CUDA device.
        """
        suffix = f".{selected_channel}_proj"
        self.sorted_channels = []
        # load channel config
        with open(
            self.server_args.ds_channel_config_path, "r", encoding="utf-8"
        ) as f:
            channel_config = json.load(f)

        heavy_channel_num = self.server_args.ds_heavy_channel_num
        for i in range(self.model_config.num_hidden_layers):
            key = f"model.layers.{i}.self_attn{suffix}"
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[:, :heavy_channel_num]
                .contiguous()
                .cuda(self.gpu_id)
            )

616
    def init_cuda_graphs(self):
        """Capture CUDA graphs into ``self.cuda_graph_runner``.

        ``self.cuda_graph_runner`` stays ``None`` for non-generation models or
        when ``server_args.disable_cuda_graph`` is set.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        start = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
        elapsed = time.time() - start
        logger.info(f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s")
633

634
635
636
637
638
639
640
    def apply_torch_tp(self):
        """Apply torch-native tensor parallelism to ``self.model``.

        Builds a 1-D device mesh of ``tp_size`` devices and hands it to
        ``sglang.srt.model_parallel.tensor_parallel``.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

641
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run the model on a decode batch.

        If a CUDA graph runner exists and reports it can handle the batch, its
        replay result is returned directly. Otherwise ``forward_batch.positions``
        is overwritten with ``seq_lens - 1`` (as int64), attention metadata is
        initialised, and the model is run eagerly.
        """
        graph_runner = self.cuda_graph_runner
        if graph_runner and graph_runner.can_run(forward_batch):
            return graph_runner.replay(forward_batch)

        # Presumably seq_lens already counts the token being decoded, so its
        # position is the last index of each sequence — confirm with the scheduler.
        last_index = forward_batch.seq_lens - 1
        forward_batch.positions = last_index.to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        """Run the model on an extend batch.

        Attention metadata is initialised first. Non-generation (embedding)
        models are called with ``get_embedding=True``. Generation models also
        receive ``input_embeds`` when ``forward_batch.input_embeds`` is set.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        positional = (forward_batch.input_ids, forward_batch.positions, forward_batch)

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(*positional, get_embedding=True)

        embeds = forward_batch.input_embeds
        if embeds is None:
            return self.model.forward(*positional)
        # NOTE(review): embeddings are always cast to bfloat16 here, regardless
        # of the dtype the model was loaded in — confirm this is intended.
        return self.model.forward(*positional, input_embeds=embeds.bfloat16())

    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the model on an idle batch.

        CUDA-graph replay is used when a runner exists and accepts the batch.
        Otherwise the model is called eagerly. Unlike ``forward_decode``, this
        path neither rewrites positions nor initialises attention metadata.
        """
        graph_runner = self.cuda_graph_runner
        if not (graph_runner and graph_runner.can_run(forward_batch)):
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        return graph_runner.replay(forward_batch)

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Run one forward pass, dispatching on ``forward_batch.forward_mode``.

        Decode, extend and idle batches are routed to ``forward_decode``,
        ``forward_extend`` and ``forward_idle`` respectively. The modes are
        checked in that order.

        Raises:
            ValueError: if the forward mode is none of decode / extend / idle.
        """
        mode = forward_batch.forward_mode
        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(forward_batch)
        if mode.is_idle():
            return self.forward_idle(forward_batch)
        raise ValueError(f"Invalid forward mode: {mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Pick the next token ids from ``logits_output.next_token_logits``.

        Per-batch sampling state is prepared first. In normal mode, the regex
        vocab mask and the penalties are updated here. In overlap mode,
        ``sampling_info_done`` is set and is waited on when grammars are
        present. The logits are then adjusted via ``apply_logits_bias`` and
        handed to ``self.sampler``.
        """
        info = forward_batch.sampling_info
        done_signal = info.sampling_info_done

        if not done_signal:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            info.update_regex_vocab_mask()
            info.update_penalties()
        elif info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            done_signal.wait()

        adjusted = self.apply_logits_bias(logits_output.next_token_logits, info)

        # Sample the next tokens.
        return self.sampler(adjusted, info)

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        """Adjust next-token logits according to ``sampling_info``.

        The steps run in this order, and each is skipped when its field is None:
          1. ``logit_bias`` is added in place.
          2. ``linear_penalties`` (min-token, presence, frequency) are added in place.
          3. ``scaling_penalties`` (repetition) divide positive logits and
             multiply the rest; this produces a new tensor.
          4. ``vocab_mask`` (regex constraint) is applied via
             ``sampling_info.apply_mask``.

        Steps 1 and 2 mutate the tensor passed in as ``logits``.

        Returns:
            The adjusted logits. This is the input tensor object itself unless
            step 3 ran.
        """
        # Additive adjustments: logit_bias first, then min-token/presence/frequency.
        for additive in (sampling_info.logit_bias, sampling_info.linear_penalties):
            if additive is not None:
                logits.add_(additive)

        # repetition
        repetition_scale = sampling_info.scaling_penalties
        if repetition_scale is not None:
            is_positive = logits > 0
            logits = torch.where(
                is_positive,
                logits / repetition_scale,
                logits * repetition_scale,
            )

        # Apply regex vocab_mask
        allowed_mask = sampling_info.vocab_mask
        if allowed_mask is not None:
            sampling_info.apply_mask(logits=logits, vocab_mask=allowed_mask)

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model uses "mrope" (multimodal rotary) position embeddings.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases. The model is treated as mrope when ``hf_config.rope_scaling``
        has any of:
          * ``type == "mrope"`` (the raw config format), or
          * ``rope_type == "mrope"``, or
          * an ``"mrope_section"`` key. This key survives newer transformers
            releases, which rewrite ``type: "mrope"`` to ``"default"`` when
            normalising the config.

        Returns False when ``rope_scaling`` is missing, None or empty.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", None)
        if not rope_scaling:
            return False
        if "mrope_section" in rope_scaling:
            return True
        return "mrope" in (rope_scaling.get("type"), rope_scaling.get("rope_type"))