model_runner.py 26.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
"""ModelRunner runs the forward passes of the models."""
17

18
import gc
Cody Yu's avatar
Cody Yu committed
19
import importlib
20
import importlib.resources
Shuo Yang's avatar
Shuo Yang committed
21
import json
22
23
import logging
import pkgutil
Cody Yu's avatar
Cody Yu committed
24
from functools import lru_cache
25
from typing import Optional, Type
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27

import torch
28
29
30
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
zhyncs's avatar
zhyncs committed
31
32
33
34
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
35
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
36
)
37
from vllm.distributed.parallel_state import in_the_same_node_as
38
from vllm.model_executor.model_loader import get_model
39
from vllm.model_executor.models import ModelRegistry
Lianmin Zheng's avatar
Lianmin Zheng committed
40

41
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
42
from sglang.srt.constrained import disable_cache
Shuo Yang's avatar
Shuo Yang committed
43
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
44
45
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Liangsheng Yin's avatar
Liangsheng Yin committed
46
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
47
from sglang.srt.layers.sampler import Sampler
48
from sglang.srt.lora.lora_manager import LoRAManager
49
from sglang.srt.managers.schedule_batch import global_server_args_dict
50
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
51
    DoubleSparseTokenToKVPool,
52
53
54
55
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
56
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
57
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
58
from sglang.srt.server_args import ServerArgs
59
from sglang.srt.utils import (
60
    enable_show_time_cost,
61
    get_available_gpu_memory,
62
63
    is_attention_free_model,
    is_embedding_model,
64
    is_generation_model,
65
    is_multimodal_model,
66
    model_has_inner_state,
67
    monkey_patch_vllm_dummy_weight_loader,
68
69
    monkey_patch_vllm_p2p_access_check,
)
70

Ying Sheng's avatar
Ying Sheng committed
71
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
72

Lianmin Zheng's avatar
Lianmin Zheng committed
73
74

class ModelRunner:
75
76
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
77
78
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Build the model runner for one tensor-parallel rank.

        Args:
            model_config: sglang model configuration.
            mem_fraction_static: Memory fraction used when sizing the KV cache;
                ``1 - mem_fraction_static`` of the total device memory is kept
                free (see ``profile_max_num_token``).
            gpu_id: Local device index this rank runs on.
            tp_rank: Rank within the tensor-parallel group.
            tp_size: Size of the tensor-parallel group.
            nccl_port: Port for the ``tcp://127.0.0.1`` init method, used when
                ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments. Several fields (attention backend,
                CUDA-graph flags, chunked prefill, memory fraction) are adjusted
                in place for model-specific constraints.

        Raises:
            ValueError: if double sparsity is enabled without a heavy channel type.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95
            # profile_max_num_token() reads self.mem_fraction_static, which was
            # copied from the constructor argument above. Scale it too; otherwise
            # the adjustment announced in the log message never takes effect.
            self.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_cuda_graph = True

        if self.server_args.enable_overlap_schedule:
            logger.warning(
                "Overlap scheduler is enabled. This is an experimental feature. "
                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                "and embedding APIs are not supported and will lead to wrong results."
            )

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "disable_nan_detection": server_args.disable_nan_detection,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()
177
178

    def init_torch_distributed(self):
        """Initialize torch.distributed and the tensor-parallel group for this rank.

        Returns:
            Available device memory in GB. When ``tp_size > 1`` it is queried
            with ``distributed=True`` (presumably the minimum over the TP ranks).
            ``__init__`` passes it to ``init_memory_pool`` to size the KV cache.

        Raises:
            ValueError: if ``self.device`` is neither "cuda" nor "xpu", or if this
                rank's free memory is well above the reported minimum (some
                devices may be occupied by other processes).
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        elif self.device == "xpu":
            # TODO(liangan1): Just use gloo to bypass the initialization failure.
            # Need to use xccl for the xpu backend in the future.
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"
        else:
            # Without this branch `backend` would be unbound below, surfacing as an
            # opaque UnboundLocalError instead of a clear configuration error.
            raise ValueError(f"Unsupported device for distributed init: {self.device}")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism: a large gap between the
        # group-wide value and this rank's own free memory indicates imbalance.
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
229

Lianmin Zheng's avatar
Lianmin Zheng committed
230
    def load_model(self):
        """Load the model weights through vllm and record model-level properties.

        Sets ``self.load_config``, ``self.vllm_model_config``, ``self.dtype``,
        ``self.model``, ``self.sliding_window_size``, ``self.has_cross_attention``
        and ``self.is_generation``. On CUDA devices below sm80 it forces
        ``server_args.dtype`` to "float16".

        Raises:
            RuntimeError: on CUDA devices with compute capability below sm75.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            capability = torch.cuda.get_device_capability()
            # Compare (major, minor) as a tuple. Checking only the minor digit
            # lets old architectures such as sm37 (minor 7 >= 5) slip through.
            if capability < (7, 5):
                raise RuntimeError("SGLang only supports sm75 and above.")
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
291

292
293
    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place.

        Builds a vllm model config for ``model_path`` and streams its weights
        into the existing ``self.model`` with vllm's ``DefaultModelLoader``. On
        success, records the new path, format and configs on the runner.

        If building the config, the loader or the weight iterator fails, the
        current weights are left untouched. If loading the new weights fails
        part-way, the weights of the previous checkpoint are reloaded.

        Args:
            model_path: Checkpoint to load the new weights from.
            load_format: vllm load format for the new checkpoint.

        Returns:
            A ``(success, message)`` tuple.
        """
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            # NOTE(review): the rollback path below also uses this (new) loader
            # for the previous checkpoint; confirm this is fine when the
            # previous load_format differs.
            return loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            model.load_weights(weights_iter)
            # Post-process quantized modules of the model that was just loaded.
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                weights_iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Reload the previous checkpoint under its own dtype, not the new one.
                with set_default_torch_dtype(self.vllm_model_config.dtype):
                    weights_iter = get_weight_iter(self.vllm_model_config)
                    self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
380

381
382
383
384
385
386
387
388
389
390
391
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for the adapters listed in ``server_args.lora_paths``.

        Reuses the already-loaded base model, its HF config, the load config and
        the model dtype set by ``load_model``.
        """
        args = self.server_args
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

392
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens fit in the KV cache.

        Args:
            total_gpu_memory: Device memory in GB. The share
                ``1 - mem_fraction_static`` of it is held back and not counted
                as usable for the KV cache.

        Returns:
            The token capacity as an int. This is zero or negative when the
            currently available memory is below the held-back share.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        cfg = self.model_config
        bytes_per_elem = torch._utils._element_size(self.kv_cache_dtype)

        uses_mla = cfg.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla
        if uses_mla:
            # MLA: (kv_lora_rank + qk_rope_head_dim) elements per token per layer.
            elems_per_layer = cfg.kv_lora_rank + cfg.qk_rope_head_dim
        else:
            # MHA: kv_heads * head_dim elements, times 2 (presumably separate K
            # and V buffers), per token per layer.
            elems_per_layer = cfg.get_num_kv_heads(self.tp_size) * cfg.head_dim * 2
        cell_size = elems_per_layer * cfg.num_hidden_layers * bytes_per_elem

        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        # GB -> bytes, then divide by the per-token footprint.
        return int(rest_memory * (1 << 30) // cell_size)

419
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Create the request-to-token pool and the token-to-KV-cache pool.

        Args:
            total_gpu_memory: Device memory in GB, forwarded to
                ``profile_max_num_token`` to size the KV cache.
            max_num_reqs: Capacity of the request pool. When None, it is derived
                from the token budget and clamped to [2048, 4096].
            max_total_tokens: Optional user cap on KV-cache tokens. It can only
                lower the profiled value.

        Raises:
            ValueError: for an unsupported ``server_args.kv_cache_dtype``.
            RuntimeError: if no KV-cache tokens fit in memory.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (not the root `logging` module) so this
                # warning is named and filtered like the rest of this file's logs.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        # NOTE(review): the +1 slot and +4 context padding are kept from the
        # original; their purpose is not visible here — TODO confirm.
        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
500

Lianmin Zheng's avatar
Lianmin Zheng committed
501
502
503
504
505
506
507
508
509
    def init_cublas(self):
        """Warm up cuBLAS with a tiny fp16 matmul on the current CUDA device.

        We need to run a small matmul to init cublas. Otherwise, it will raise
        some errors later. Returns the 16x16 product.
        """
        lhs = torch.ones((16, 16), dtype=torch.float16, device="cuda")
        rhs = torch.ones_like(lhs)
        return torch.matmul(lhs, rhs)

510
511
512
513
514
515
516
517
    def init_attention_backend(self):
        """Init attention kernel backend.

        Builds ``self.attn_backend`` from ``server_args.attention_backend``.
        With "triton", double sparsity selects ``DoubleSparseAttnBackend``.

        Raises:
            ValueError: if the backend name is unknown, or if the triton backend
                is requested for a model with sliding-window or cross attention.
        """
        backend_name = self.server_args.attention_backend
        if backend_name == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif backend_name == "triton":
            # Explicit raises instead of asserts so these user-configuration
            # checks are not stripped under `python -O`.
            if self.sliding_window_size is not None:
                raise ValueError(
                    "Window attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )
            if self.has_cross_attention:
                raise ValueError(
                    "Cross attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(f"Invalid attention backend: {backend_name}")
531

Shuo Yang's avatar
Shuo Yang committed
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer heavy-channel data for double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path``. For each
        hidden layer ``i`` it takes the entry
        ``model.layers.{i}.self_attn.{selected_channel}_proj`` and keeps its
        first ``ds_heavy_channel_num`` columns. The resulting tensors are stored
        in ``self.sorted_channels`` (one per layer).

        Args:
            selected_channel: Projection name prefix, taken from
                ``server_args.ds_heavy_channel_type`` by ``__init__``.

        Raises:
            KeyError: if a layer's entry is missing from the config file.
        """
        suffix = f".{selected_channel}_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        # __init__ calls this before init_torch_distributed() runs
        # torch.cuda.set_device(), so a bare `.cuda()` would place every TP
        # rank's tensors on the default device. Target this rank's device.
        device = torch.device(self.device, self.gpu_id)
        heavy_channel_num = self.server_args.ds_heavy_channel_num
        for i in range(self.model_config.num_hidden_layers):
            key = f"model.layers.{i}.self_attn{suffix}"
            channels = torch.tensor(channel_config[key])[:, :heavy_channel_num]
            self.sorted_channels.append(channels.contiguous().to(device))

550
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        ``self.cuda_graph_runner`` stays None for non-generation models or when
        ``server_args.disable_cuda_graph`` is set.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
565

566
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run one decode step.

        Replays a captured CUDA graph when the runner exists and accepts this
        batch size; otherwise runs the model eagerly.
        """
        graph_runner = self.cuda_graph_runner
        if graph_runner and graph_runner.can_run(forward_batch.batch_size):
            return graph_runner.replay(forward_batch)

        # Eager path: positions are seq_lens - 1 (the last position of each sequence).
        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

578
    def forward_extend(self, forward_batch: ForwardBatch):
        """Run an extend step; non-generation models are asked for embeddings."""
        self.attn_backend.init_forward_metadata(forward_batch)
        model_args = (forward_batch.input_ids, forward_batch.positions, forward_batch)
        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(*model_args, get_embedding=True)
        return self.model.forward(*model_args)
Lianmin Zheng's avatar
Lianmin Zheng committed
592

593
594
595
596
597
    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch ``forward_batch`` to the decode or extend path by its forward mode.

        Raises:
            ValueError: if the forward mode is neither decode nor extend.
        """
        mode = forward_batch.forward_mode
        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(forward_batch)
        raise ValueError(f"Invalid forward mode: {mode}")
600

601
602
603
604
605
606
607
608
609
610
611
612
613
614
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Adjust the next-token logits and sample the next token ids.

        Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        """
        info = forward_batch.sampling_info
        # Refresh the regex vocab mask and penalties before they are applied.
        info.update_regex_vocab_mask()
        info.update_penalties()
        adjusted = self.apply_logits_bias(logits_output.next_token_logits, info)
        # Sample the next tokens.
        return self.sampler(adjusted, info)

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        """Apply logit bias, penalties and the vocab mask, returning the new logits.

        The additive terms (``logit_bias``, ``linear_penalties``) are added in
        place, so the caller's ``logits`` tensor is modified. The repetition
        scaling and the vocab mask produce new tensors.
        """
        # Apply logit_bias
        bias = sampling_info.logit_bias
        if bias is not None:
            logits.add_(bias)

        # min-token, presence, frequency
        linear = sampling_info.linear_penalties
        if linear is not None:
            logits.add_(linear)

        # repetition: shrink positive logits, push negative ones further down
        scaling = sampling_info.scaling_penalties
        if scaling is not None:
            logits = torch.where(logits > 0, logits / scaling, logits * scaling)

        # Apply regex vocab_mask
        mask = sampling_info.vocab_mask
        if mask is not None:
            logits = logits.masked_fill(mask, float("-inf"))

        return logits

Yineng Zhang's avatar
Yineng Zhang committed
637
638
639
640
641
642
643
644
645
    @property
    def model_is_mrope(self) -> bool:
        """Whether the HF config's ``rope_scaling`` type is "mrope".

        mrope requires keeping "rope_deltas" between the prompt and decoding phases.
        A missing or None ``rope_scaling`` yields False.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        return rope_scaling is not None and rope_scaling.get("type", None) == "mrope"

646
647
648
649
650
651
652
653

@lru_cache()
def import_model_classes():
    """Map class names to the model classes exported by ``sglang.srt.models``.

    Imports every non-package module in that package. A module that fails to
    import is skipped with a warning. A module contributes classes through a
    module-level ``EntryClass``, which is either one class or a list of classes.
    The result is cached by ``lru_cache``, so the scan runs once per process.

    Raises:
        AssertionError: if two modules register the same class name.
    """
    registry = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, module_name, is_pkg in pkgutil.iter_modules(
        package.__path__, package_name + "."
    ):
        if is_pkg:
            continue
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            logger.warning(f"Ignore import error when loading {module_name}. " f"{e}")
            continue
        if not hasattr(module, "EntryClass"):
            continue
        entry = module.EntryClass
        # To support multiple model classes in one module
        candidates = entry if isinstance(entry, list) else [entry]
        for model_cls in candidates:
            assert (
                model_cls.__name__ not in registry
            ), f"Duplicated model implementation for {model_cls.__name__}"
            registry[model_cls.__name__] = model_cls
    return registry


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Return the sglang model class registered under ``model_arch``.

    This module installs it as ``ModelRegistry._try_load_model_cls``.

    Raises:
        ValueError: if no sglang model class is registered for ``model_arch``.
    """
    registry = import_model_classes()
    model_cls = registry.get(model_arch)
    if model_cls is None:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(registry)}"
        )
    return model_cls


# Monkey patch vllm's ModelRegistry: architecture lookups resolve to sglang's
# own model classes, and the capability queries use the helpers imported from
# sglang.srt.utils.
ModelRegistry._try_load_model_cls = load_model_cls_srt
ModelRegistry.is_multimodal_model = is_multimodal_model
ModelRegistry.is_attention_free_model = is_attention_free_model
ModelRegistry.model_has_inner_state = model_has_inner_state
ModelRegistry.is_embedding_model = is_embedding_model