model_runner.py 30.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Cody Yu's avatar
Cody Yu committed
17
import importlib
18
import importlib.resources
19
import inspect
Shuo Yang's avatar
Shuo Yang committed
20
import json
21
22
import logging
import pkgutil
Cody Yu's avatar
Cody Yu committed
23
from functools import lru_cache
24
from typing import Optional, Type
Lianmin Zheng's avatar
Lianmin Zheng committed
25
26

import torch
27
28
29
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
zhyncs's avatar
zhyncs committed
30
31
32
33
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
34
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
35
)
36
from vllm.distributed.parallel_state import in_the_same_node_as
37
from vllm.model_executor.model_loader import get_model
38
from vllm.model_executor.models import ModelRegistry
Lianmin Zheng's avatar
Lianmin Zheng committed
39

40
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
Shuo Yang's avatar
Shuo Yang committed
41
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
42
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
43
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
44
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Liangsheng Yin's avatar
Liangsheng Yin committed
45
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
46
from sglang.srt.layers.sampler import Sampler
47
from sglang.srt.lora.lora_manager import LoRAManager
48
from sglang.srt.managers.schedule_batch import global_server_args_dict
49
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
50
    DoubleSparseTokenToKVPool,
51
52
53
54
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
55
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
56
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
57
from sglang.srt.server_args import ServerArgs
58
from sglang.srt.utils import (
59
    crash_on_warnings,
60
    enable_show_time_cost,
61
    get_available_gpu_memory,
HAI's avatar
HAI committed
62
    is_hip,
63
    monkey_patch_vllm_gguf_config,
64
    monkey_patch_vllm_model_config,
65
    monkey_patch_vllm_p2p_access_check,
66
    set_cpu_offload_max_bytes,
67
)
68

Ying Sheng's avatar
Ying Sheng committed
69
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
70

Lianmin Zheng's avatar
Lianmin Zheng committed
71
72

class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Set up distributed state, load the model, and allocate KV-cache pools.

        Args:
            model_config: SGLang model configuration.
            mem_fraction_static: Fraction of total GPU memory used for static
                allocations (weights + KV cache); see ``profile_max_num_token``.
            gpu_id: Local device index for this runner.
            tp_rank: Tensor-parallel rank of this runner.
            tp_size: Tensor-parallel world size.
            nccl_port: Port used for the local TCP rendezvous when
                ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments; some fields are adjusted in place below.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = -1
            self.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        """Initialize the torch distributed environment and tensor-parallel groups.

        Returns:
            The available GPU memory (GB) reported by ``get_available_gpu_memory``
            with ``distributed=self.tp_size > 1`` -- presumably the minimum across
            TP ranks in the distributed case.

        Raises:
            ValueError: if the device type has no known distributed backend, or if
                available memory is unbalanced across TP ranks.
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            # TODO(liangan1): Just use gloo to bypass the initialization failure.
            # Need to use xccl for the xpu backend in the future.
            backend = "gloo"
        elif self.device == "hpu":
            backend = "hccl"
        else:
            # Without this branch `backend` would be unbound and the call below
            # would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device}"
            )

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def setup_model(self):
        """Build the model through vllm's ``get_model``.

        Newer vllm releases expose ``VllmConfig`` and take a single
        ``vllm_config`` argument; older releases take the individual configs.
        """
        try:
            from vllm.config import VllmConfig
        except ImportError:
            # Older vllm: fall back to the legacy keyword signature. Only the
            # import is guarded, so ImportErrors raised while actually loading
            # the model are no longer swallowed.
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )

        vllm_config = VllmConfig()
        vllm_config.model_config = self.vllm_model_config
        vllm_config.load_config = self.load_config
        vllm_config.device_config = DeviceConfig(self.device)
        vllm_config.quant_config = VllmConfig._get_quantization_config(
            vllm_config.model_config, vllm_config.load_config
        )
        return get_model(vllm_config=vllm_config)

    def get_model_config_params(self):
        """Return the keyword arguments used to build a vllm ``ModelConfig``.

        ``task`` is only passed when the installed vllm's ``ModelConfig.__init__``
        accepts it.
        """
        sig = inspect.signature(VllmModelConfig.__init__)
        params = {
            "model": self.server_args.model_path,
            "quantization": self.server_args.quantization,
            "tokenizer": None,
            "tokenizer_mode": None,
            "trust_remote_code": self.server_args.trust_remote_code,
            "dtype": self.server_args.dtype,
            "seed": self.server_args.random_seed,
            "skip_tokenizer_init": True,
        }

        if "task" in sig.parameters:
            params["task"] = ""

        return params

    def load_model(self):
        """Load model weights and record the model dtype and sliding-window size.

        Raises:
            RuntimeError: on CUDA devices below compute capability sm75.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                # Compare (major, minor) together: checking the minor number alone
                # would let older GPUs such as sm37 (3.7) through.
                if (major, minor) < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        monkey_patch_vllm_model_config()
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()
        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        self.model = self.setup_model()

        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.vllm_model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights_from_disk(self, model_path: str, load_format: str):
        """Update engine weights online from disk.

        Args:
            model_path: Path or model id to load the new weights from.
            load_format: vllm load format; only loaders of type
                ``DefaultModelLoader`` are supported.

        Returns:
            A ``(success, message)`` tuple. If loading the new weights fails, the
            original weights are reloaded before returning ``False``.
        """
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            model_config_params = self.get_model_config_params()
            model_config_params["model"] = model_path
            vllm_model_config = VllmModelConfig(**model_config_params)
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            # Build a weights iterator for the checkpoint described by `config`.
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights):
            # Copy weights into `model`, then let quantized modules post-process them.
            model.load_weights(weights)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                weights = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                # The model may be partially overwritten; reload the original
                # checkpoint so the runner stays in a consistent state.
                del weights
                gc.collect()
                weights = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, weights)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Returns:
            Whatever the model's ``get_weights_by_name`` returns, or ``None`` if
            that call raises (the error is logged).
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        """Create the LoRA manager for the adapters in ``server_args.lora_paths``."""
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit in the remaining memory.

        Args:
            total_gpu_memory: Memory measured before model loading, in GB (the
                result is converted to bytes with ``1 << 30``).

        Returns:
            The number of token slots, i.e. spare bytes divided by per-token KV size.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # MLA caches one compressed latent (+ rope part) per token per layer.
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            # MHA caches K and V (hence * 2) for every KV head on this rank.
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        # Keep (1 - mem_fraction_static) of the total memory free for activations etc.
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Allocate the request-to-token pool and the token-to-KV pool.

        Args:
            total_gpu_memory: Memory (GB) measured before model loading.
            max_num_reqs: Maximum concurrent requests; derived from the token
                budget when ``None``.
            max_total_tokens: Optional cap on the profiled token capacity.

        Raises:
            ValueError: for an unsupported ``kv_cache_dtype``.
            RuntimeError: when no memory is left for the KV cache.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger rather than the root `logging` logger.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            # Heuristic: scale with how many full-context requests fit, clamped to [2048, 4096].
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend selected by ``server_args.attention_backend``."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer heavy-channel indices for double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path`` and, for each
        hidden layer, keeps the first ``ds_heavy_channel_num`` columns of the entry
        ``model.layers.{i}.self_attn.{selected_channel}_proj``.
        """
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r", encoding="utf-8") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            # This runs from __init__ before init_torch_distributed() calls
            # set_device(gpu_id), so target this runner's GPU explicitly instead
            # of whatever the current CUDA device happens to be.
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda(self.gpu_id)
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs (generation models only, unless disabled)."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def apply_torch_tp(self):
        """Shard the model with torch tensor parallelism over a 1-D device mesh."""
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
        """Run a decode step, replaying a captured CUDA graph when possible."""
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        # The new token of each request sits at position seq_len - 1.
        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        """Run an extend (prefill) step; embedding models return embeddings."""
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            if forward_batch.input_embeds is None:
                return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
                )
            else:
                # NOTE(review): input embeddings are cast to bfloat16 regardless of
                # self.dtype -- confirm this is intended for float16 models.
                return self.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    input_embeds=forward_batch.input_embeds.bfloat16(),
                )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward_idle(self, forward_batch: ForwardBatch):
        """Run an idle step, replaying a captured CUDA graph when possible."""
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch to the decode / extend / idle forward path.

        Raises:
            ValueError: for any other forward mode.
        """
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Apply penalties/biases/masks to the next-token logits and sample."""
        sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                # presumably an event set once the vocab mask is ready -- verify
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
            sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        """Apply logit bias, penalties, and the vocab mask to ``logits``.

        The additive terms are applied in place (``add_``); the repetition
        penalty produces a new tensor, which is what is returned.
        """
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # repetition: shrink positive logits, push negative logits further down
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"

724
725
726
727
728
729
730
731

@lru_cache()
def import_model_classes(package_name: str = "sglang.srt.models"):
    """Discover the model classes exported by the modules of a models package.

    Every non-package module directly under ``package_name`` is imported. If
    it defines ``EntryClass`` (a single class, or a list of classes), each
    class is registered under its ``__name__``.

    Args:
        package_name: Dotted name of the package to scan. Defaults to SGLang's
            built-in models package.

    Returns:
        A dict mapping architecture (class) name to model class. Results are
        memoized per ``package_name`` by ``lru_cache``.

    Raises:
        ValueError: If two modules export classes with the same ``__name__``,
            or if a module fails to import while ``crash_on_warnings()`` is
            true.
    """
    model_arch_name_to_cls = {}
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue
        try:
            module = importlib.import_module(name)
        except Exception as e:
            # Best effort: a model file that fails to import is logged and
            # skipped so the remaining models stay available.
            logger.warning("Ignore import error when loading %s. %s", name, e)
            if crash_on_warnings():
                raise ValueError(
                    f"Ignore import error when loading {name}. {e}"
                ) from e
            continue

        if not hasattr(module, "EntryClass"):
            continue
        entry = module.EntryClass
        # A module may export several model classes at once as a list.
        classes = entry if isinstance(entry, list) else [entry]
        for cls in classes:
            # Explicit check instead of `assert`, which `python -O` strips;
            # without it a duplicate would silently overwrite an existing model.
            if cls.__name__ in model_arch_name_to_cls:
                raise ValueError(
                    f"Duplicated model implementation for {cls.__name__}"
                )
            model_arch_name_to_cls[cls.__name__] = cls
    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Look up the SGLang model class registered for ``model_arch``.

    This module installs this function on vLLM's ``ModelRegistry`` as
    ``_try_load_model_cls``, so architecture lookups resolve to SGLang's own
    model implementations.

    Raises:
        ValueError: If no SGLang model class is registered under
            ``model_arch``; the message lists the supported names.
    """
    registry = import_model_classes()
    if model_arch in registry:
        return registry[model_arch]
    raise ValueError(
        f"Unsupported architectures: {model_arch}. "
        f"Supported list: {list(registry)}"
    )


# Monkey patch vLLM's model registry: architecture lookups are routed to
# load_model_cls_srt, and the four queries below answer False no matter which
# architectures are passed in.
ModelRegistry._try_load_model_cls = load_model_cls_srt
ModelRegistry.is_multimodal_model = lambda model_architectures: False
ModelRegistry.is_attention_free_model = lambda model_architectures: False
ModelRegistry.model_has_inner_state = lambda model_architectures: False
ModelRegistry.is_embedding_model = lambda model_architectures: False