"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            logger.warning(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            self.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
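        """Initialize torch.distributed and the model-parallel groups.

        Returns the minimum available GPU memory across all tensor-parallel ranks.
        """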
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for the xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
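        """Load the model weights via vLLM's model loader and record dtype/sliding-window info."""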
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.vllm_model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_lora_manager(self):
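        """Create the LoRAManager for the LoRA adapter paths given in the server args."""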
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
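        """Profile the maximum number of tokens the KV cache can hold.

        Derived from the currently available GPU memory, the static memory
        fraction, and the per-token KV cache cell size.
        """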
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # MLA stores one compressed KV entry per token:
            # (kv_lora_rank + qk_rope_head_dim) elements per layer.
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            # Standard MHA stores separate K and V tensors (the factor of 2)
            # for every KV head on this TP rank, per layer.
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
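        """Allocate the request-to-token pool and the token-to-KV-cache pool."""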
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
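        """Load the per-layer heavy-channel indices for double sparsity from the channel config file."""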

        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

    def forward_decode(self, forward_batch: ForwardBatch):
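        """Run a decode forward pass, replaying a captured cuda graph when possible."""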
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
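        """Run an extend (prefill) forward pass; embedding models return embeddings instead of logits."""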
        self.attn_backend.init_forward_metadata(forward_batch)
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward_idle(self, forward_batch: ForwardBatch):
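        """Run a forward pass for an idle batch, replaying a captured cuda graph when possible."""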
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
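        """Dispatch the forward pass according to the batch's forward mode (decode, extend, or idle)."""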
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
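        """Apply penalties and constraint masks to the logits, then sample the next tokens."""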
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
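        """Apply logit bias, penalties, and the constraint vocab mask to the logits."""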
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # Min-new-tokens, presence, and frequency penalties (additive)
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # Repetition penalty: divide positive logits, multiply negative ones
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)

        return logits

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"


@lru_cache()
def import_model_classes():
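    """Import all modules under sglang.srt.models and collect their EntryClass model classes."""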
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
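    """Resolve a model architecture name to its SGLang model class."""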
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)