"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    is_attention_free_model,
    is_embedding_model,
    is_generation_model,
    is_multimodal_model,
    model_has_inner_state,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            logger.warning(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support cuda graph now, set disable_cuda_graph=True automatically
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_cuda_graph = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "disable_nan_detection": server_args.disable_nan_detection,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for the xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
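                # Roll back: drop the partially consumed iterator and reload the
                # original model's weights so serving can continue unchanged.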
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
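        # Per-token KV cache size summed over all layers: MLA caches one compressed
        # latent of width (kv_lora_rank + qk_rope_head_dim) per layer, while MHA/GQA
        # caches separate K and V tensors per KV head, hence the extra factor of 2.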
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

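        # Default heuristic: scale the number of request slots with how many
        # full-context sequences fit in the token pool (* 512), clamped to [2048, 4096].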
        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):

        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

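        # For each layer, keep the first `ds_heavy_channel_num` channel indices per head
        # of the selected projection, as precomputed in the channel config file.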
        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

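        # Decode generates exactly one new token per request, so its position index is seq_len - 1.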
        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # Repetition penalty: divide positive logits and multiply negative logits by the penalty
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"


@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
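    # Walk every module under sglang.srt.models and collect the model classes
    # they export via an `EntryClass` attribute.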
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. {e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
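# Route vLLM's model-class resolution and model-capability checks to SGLang's implementations.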
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)