"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
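            # Double sparsity needs the Triton backend and an offline heavy-channel config,
            # and it currently cannot be captured in a CUDA graph.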
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            logger.warning(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            self.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support radix cache yet, so disable_radix_cache is set automatically.
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "disable_nan_detection": server_args.disable_nan_detection,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for the xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
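            # Build the vLLM model config for the new checkpoint first; bail out before
            # touching the current weights if it cannot be constructed.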
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
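            # Stream in the new weights; if anything fails, reload the original checkpoint
            # so the runner is left in a usable state.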
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Successfully updated model weights."

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
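        # Estimate how many KV-cache tokens fit in the memory left after static allocations.
        # cell_size is the per-token KV-cache footprint in bytes: MLA stores one compressed
        # latent (kv_lora_rank + qk_rope_head_dim) per layer, while regular MHA stores both
        # K and V for every local KV head.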
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
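            # Heuristic: scale the request slots with the token budget, clamped to [2048, 4096].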
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
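        # Load the offline-profiled channel ranking for the selected self-attention projection
        # and keep only the top ds_heavy_channel_num channels per layer on the GPU.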

        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models.
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
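        # Replay a captured CUDA graph when the batch fits one; otherwise run the decode step eagerly.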
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
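        # Extend (prefill) always runs eagerly; embedding models additionally request pooled embeddings.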
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
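        # Applied in order: additive logit bias, linear (min-token/presence/frequency) penalties,
        # repetition scaling, then the regex vocab mask.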
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # repetition
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"


@lru_cache()
def import_model_classes():
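    # Walk every module under sglang.srt.models and collect the model classes each one
    # exports via `EntryClass` (a single class or a list of classes).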
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch the vLLM model loader so ModelRegistry resolves architectures to SGLang's model classes
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)