"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            logger.warning(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            self.mem_fraction_static *= 0.95
            # TODO: Qwen2-VL does not support radix cache yet, so disable_radix_cache is set automatically.
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "disable_nan_detection": server_args.disable_nan_detection,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for the xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
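                # Roll back by re-streaming the original checkpoint's weights.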
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
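        # Estimate how many KV-cache tokens fit into the memory left over after the
        # static allocation. `cell_size` is the per-token KV-cache footprint in bytes:
        # MLA stores a single compressed latent (kv_lora_rank + qk_rope_head_dim) per
        # layer, while standard MHA stores both K and V (the factor of 2) per KV head.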
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
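            # Heuristic: scale the number of request slots with the token budget
            # (roughly 512 requests per full context window), clamped to [2048, 4096].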
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

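        # req_to_token_pool maps each running request slot to the KV-cache indices of
        # its tokens; token_to_kv_pool below holds the actual KV tensors.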
        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):

        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

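        # For each layer, keep the first `ds_heavy_channel_num` channel indices of the
        # selected attention projection on the GPU for the double-sparsity backend.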
        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graphs only capture decode steps, which only exist for generation models.
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

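        # In decode mode, each request generates exactly one new token, so its position
        # is simply seq_len - 1.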
        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter.
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # repetition
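        # (penalties > 1 shrink positive logits and amplify negative ones, making
        # repeated tokens less likely regardless of sign)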
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"


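# Scan the sglang.srt.models package once (cached by lru_cache) and build a map from
# model architecture names to their SGLang implementation classes.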
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
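# (The ModelRegistry hooks are overridden so that vllm's get_model() resolves
# architectures to SGLang's own implementations and treats every architecture as a
# standard non-multimodal, non-embedding model.)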
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)