# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    crash_on_warnings,
    enable_show_time_cost,
    get_available_gpu_memory,
    is_hip,
    monkey_patch_vllm_model_config,
    monkey_patch_vllm_p2p_access_check,
    set_cpu_offload_max_bytes,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            self.mem_fraction_static *= 0.95
            # TODO: Qwen2-VL does not support radix cache yet; disable it automatically.
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()

        # Apply torch TP if model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for xpu backend in the future
        elif self.device == "xpu":
            backend = "gloo"
        elif self.device == "hpu":
            backend = "hccl"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def setup_model(self):
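        # Newer vLLM releases bundle all sub-configs into a single VllmConfig,
        # while older releases pass each config to get_model() separately.
        # Try the new API first and fall back on ImportError.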
        try:
            from vllm.config import VllmConfig

            vllm_config = VllmConfig()
            vllm_config.model_config = self.vllm_model_config
            vllm_config.load_config = self.load_config
            vllm_config.device_config = DeviceConfig(self.device)
            vllm_config.quant_config = VllmConfig._get_quantization_config(
                vllm_config.model_config, vllm_config.load_config
            )
            return get_model(vllm_config=vllm_config)
        except ImportError:
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )

    def get_model_config_params(self):
        sig = inspect.signature(VllmModelConfig.__init__)
        params = {
            "model": self.server_args.model_path,
            "quantization": self.server_args.quantization,
            "tokenizer": None,
            "tokenizer_mode": None,
            "trust_remote_code": self.server_args.trust_remote_code,
            "dtype": self.server_args.dtype,
            "seed": self.server_args.random_seed,
            "skip_tokenizer_init": True,
        }

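        # Probe the signature: newer vllm versions added a "task" parameter,
        # so only pass it when the installed version expects it.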
        if "task" in sig.parameters:
            params["task"] = ""

        return params

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        monkey_patch_vllm_model_config()
        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        self.model = self.setup_model()

        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.vllm_model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            model_config_params = self.get_model_config_params()
            model_config_params["model"] = model_path
            vllm_model_config = VllmModelConfig(**model_config_params)
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

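        # Helpers: stream weights from the checkpoint, load them into the model,
        # and re-run each module's quant_method post-processing on the target device.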
        def get_weight_iter(config):
            weight_iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return weight_iter

        def model_load_weights(model, weight_iter):
            model.load_weights(weight_iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                weight_iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weight_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
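                # Roll back: reload the original weights so the server stays usable.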
                del weight_iter
                gc.collect()
                weight_iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, weight_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Successfully updated model weights."

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
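        # Each KV-cache cell holds one token's K/V across all layers. MLA stores a
        # single compressed latent (kv_lora_rank + qk_rope_head_dim) per token,
        # while MHA stores separate K and V per head.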
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
420
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
428
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
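            # Heuristic: scale request slots with the token budget
            # (~512 requests per full context length), clamped to [2048, 4096].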
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

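        # The +1 request slot and +4 token margin leave headroom at the
        # context-length boundary.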
        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # Load the channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

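        # channel_config maps keys like "model.layers.{i}.self_attn.q_proj" to
        # ranked channel indices; keep the top ds_heavy_channel_num per layer.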
        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward_idle(self, forward_batch: ForwardBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
            sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # Apply min-token, presence, and frequency penalties
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # Apply the repetition penalty
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"


@lru_cache()
def import_model_classes():
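    """Scan sglang.srt.models and map each model's EntryClass to its architecture name."""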
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. {e}")
                if crash_on_warnings():
                    raise ValueError(f"Ignore import error when loading {name}. {e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
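# sglang tracks multimodal/embedding support itself, so stub out vllm's checks.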
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)