model_runner.py 32.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Cody Yu's avatar
Cody Yu committed
17
import importlib
18
import importlib.resources
19
import inspect
Shuo Yang's avatar
Shuo Yang committed
20
import json
21
22
import logging
import pkgutil
23
import time
Cody Yu's avatar
Cody Yu committed
24
from functools import lru_cache
25
26
from tokenize import tabsize
from typing import Any, Optional, Type, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
27
28

import torch
29
import torch.distributed as dist
30
31
32
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
zhyncs's avatar
zhyncs committed
33
34
35
36
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
37
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
38
)
39
from vllm.distributed.parallel_state import in_the_same_node_as
40
from vllm.model_executor.model_loader import get_model
41
from vllm.model_executor.models import ModelRegistry
Lianmin Zheng's avatar
Lianmin Zheng committed
42

43
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
Shuo Yang's avatar
Shuo Yang committed
44
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
45
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
46
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
47
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Liangsheng Yin's avatar
Liangsheng Yin committed
48
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
49
from sglang.srt.layers.sampler import Sampler
50
from sglang.srt.lora.lora_manager import LoRAManager
51
from sglang.srt.managers.schedule_batch import global_server_args_dict
52
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
53
    DoubleSparseTokenToKVPool,
54
55
56
57
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
58
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
59
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Lianmin Zheng's avatar
Lianmin Zheng committed
60
from sglang.srt.server_args import ServerArgs
61
from sglang.srt.utils import (
62
    crash_on_warnings,
63
    enable_show_time_cost,
64
    get_available_gpu_memory,
65
    init_custom_process_group,
HAI's avatar
HAI committed
66
    is_hip,
67
    monkey_patch_vllm_gguf_config,
68
    monkey_patch_vllm_model_config,
69
    monkey_patch_vllm_p2p_access_check,
70
    set_cpu_offload_max_bytes,
71
)
72

Ying Sheng's avatar
Ying Sheng committed
73
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
74

Lianmin Zheng's avatar
Lianmin Zheng committed
75
76

class ModelRunner:
77
78
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
79
80
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Prepare a runner: distributed state, model weights, memory pools, backends.

        The construction sequence is order-sensitive:
          1. store the arguments and apply model-specific overrides to ``server_args``;
          2. publish selected server args into ``global_server_args_dict``;
          3. initialize torch distributed, which also returns the free-memory
             figure that later sizes the KV cache pool;
          4. load the weights and, when possible, apply torch-native TP;
          5. create the LoRA manager (if requested) and the memory pools;
          6. create the attention backend, plus cuBLAS warm-up and CUDA graphs
             when running on CUDA.
        """
        # ---- Arguments -----------------------------------------------------
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal

        # ---- Model-specific adjustments to server_args --------------------
        uses_mla = (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if uses_mla:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            heavy_channel_type = self.server_args.ds_heavy_channel_type
            if heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(heavy_channel_type)

        if self.is_multimodal:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = -1
            self.mem_fraction_static *= 0.95
            # TODO: qwen2-vl does not support radix cache yet, so radix cache is
            # turned off automatically for that architecture.
            architectures = self.model_config.hf_config.architectures
            if architectures == ["Qwen2VLForConditionalGeneration"]:
                server_args.disable_radix_cache = True

        # ---- Process-wide settings ----------------------------------------
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            from outlines.caching import disable_cache

            disable_cache()

        # Each published key mirrors the server_args attribute of the same name.
        global_server_args_dict.update(
            {
                key: getattr(server_args, key)
                for key in (
                    "attention_backend",
                    "sampling_backend",
                    "triton_attention_reduce_in_fp32",
                    "disable_mla",
                    "torchao_config",
                    "enable_nan_detection",
                    "enable_dp_attention",
                )
            }
        )

        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Measured before any weights are allocated; forwarded to
        # init_memory_pool to size the KV cache.
        min_per_gpu_memory = self.init_torch_distributed()

        # ---- Model ---------------------------------------------------------
        self.sampler = Sampler()
        self.load_model()

        # Torch-native TP needs TP > 1 and a model that opts in through a
        # `supports_torch_tp` attribute.
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        use_torch_tp = bool(self.tp_size > 1 and supports_torch_tp)
        if use_torch_tp:
            self.apply_torch_tp()
        self.torch_tp_applied = use_torch_tp

        # ---- Memory pools and attention backends --------------------------
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )

        on_cuda = self.device == "cuda"
        if on_cuda:
            self.init_cublas()
        else:
            self.cuda_graph_runner = None
        self.init_attention_backend()
        if on_cuda:
            self.init_cuda_graphs()
187
188

    def init_torch_distributed(self):
        """Bind this process to its device and set up torch distributed / TP groups.

        Returns:
            The free device memory reported by ``get_available_gpu_memory``
            (called with ``distributed=tp_size > 1``; presumably the minimum
            across TP ranks in that case), measured before the model is loaded.

        Raises:
            ValueError: if ``self.device`` has no known communication backend,
                or if the free memory is unbalanced across TP ranks.
        """
        logger.info("Init torch distributed begin.")

        # Communication backend per device type.
        # TODO(liangan1): gloo is used for xpu only to bypass an initialization
        # failure; switch to xccl for the xpu backend in the future.
        backends = {"cuda": "nccl", "xpu": "gloo", "hpu": "hccl"}
        backend = backends.get(self.device)
        if backend is None:
            # Previously an unknown device left `backend` unbound and failed
            # later with an UnboundLocalError.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device!r}. "
                f"Supported devices: {sorted(backends)}."
            )

        torch.get_device_module(self.device).set_device(self.gpu_id)

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # With TP, a large gap between the group-wide figure and this device's
        # free memory suggests another process is occupying some device.
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
230

231
232
233
234
235
236
237
238
239
240
241
242
243
    def setup_model(self):
        """Instantiate the model through vLLM's ``get_model``.

        Newer vLLM releases take a single ``VllmConfig``. Older ones take the
        individual config objects as keyword arguments. The API is chosen by
        whether ``vllm.config.VllmConfig`` can be imported.

        Returns:
            The model object returned by ``get_model``.
        """
        try:
            from vllm.config import VllmConfig
        except ImportError:
            # Older vLLM without VllmConfig: use the legacy keyword-argument API.
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )

        # Only the import above is guarded. An ImportError raised while building
        # or loading the model must propagate. Otherwise it would be swallowed
        # and the legacy call retried, which likely fails with an unrelated
        # signature error that hides the real cause.
        vllm_config = VllmConfig()
        vllm_config.model_config = self.vllm_model_config
        vllm_config.load_config = self.load_config
        vllm_config.device_config = DeviceConfig(self.device)
        vllm_config.quant_config = VllmConfig._get_quantization_config(
            vllm_config.model_config, vllm_config.load_config
        )
        return get_model(vllm_config=vllm_config)
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273

    def get_model_config_params(self):
        """Build the keyword arguments for constructing a vLLM ``ModelConfig``.

        Tokenizer setup is skipped (``skip_tokenizer_init=True``). A ``task``
        entry is added only when the installed vLLM ``ModelConfig`` constructor
        declares such a parameter.

        Returns:
            dict of keyword arguments for ``VllmModelConfig``.
        """
        accepted = inspect.signature(VllmModelConfig.__init__).parameters
        args = self.server_args
        params = dict(
            model=args.model_path,
            quantization=args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            seed=args.random_seed,
            skip_tokenizer_init=True,
        )
        if "task" in accepted:
            params["task"] = ""
        return params

Lianmin Zheng's avatar
Lianmin Zheng committed
274
    def load_model(self):
        """Load the model weights and record derived model properties.

        Sets ``self.load_config``, ``self.vllm_model_config``, ``self.model``,
        ``self.sliding_window_size`` and ``self.dtype``. On CUDA devices below
        compute capability 8.0 the dtype is forced to float16, because they lack
        bfloat16 support.

        Raises:
            RuntimeError: on CUDA devices with compute capability below 7.5.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            # Compare (major, minor) as a tuple. Checking the minor version on
            # its own wrongly accepted e.g. sm_35 / sm_37 (minor >= 5).
            capability = torch.cuda.get_device_capability()
            if capability < (8, 0):
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if capability < (7, 5):
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        monkey_patch_vllm_model_config()
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()
        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        self.model = self.setup_model()

        # None when the model does not expose get_attention_sliding_window_size.
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.vllm_model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
319

Chayenne's avatar
Chayenne committed
320
321
    def update_weights_from_disk(self, model_path: str, load_format: str):
        """Update engine weights online from disk.

        Loads the checkpoint at ``model_path`` (read with ``load_format``) into
        the already-instantiated ``self.model``. Only vLLM's
        ``DefaultModelLoader`` is supported. If loading the new weights raises
        part-way, the original weights are reloaded from the previous checkpoint
        before returning.

        Args:
            model_path: path or identifier of the new checkpoint.
            load_format: vLLM load format for reading the new checkpoint.

        Returns:
            A ``(success, message)`` tuple. On success, the runner's model
            path, load format and vLLM configs are switched to the new
            checkpoint.
        """
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            model_config_params = self.get_model_config_params()
            model_config_params["model"] = model_path
            vllm_model_config = VllmModelConfig(**model_config_params)
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config, weight_loader=loader):
            """Return the weights iterator of ``weight_loader`` for ``config.model``."""
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    config.model,
                    revision=config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            """Load weights into ``model``, then run each module's post-load quant hook."""
            model.load_weights(weights_iter)
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                weights_iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Roll back with the load config and dtype the current weights
                # were originally loaded with. The new `loader` may use a
                # different load format and look for different weight files.
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    # The original format has no DefaultModelLoader; fall back
                    # to the new loader, as before.
                    rollback_loader = loader
                with set_default_torch_dtype(self.vllm_model_config.dtype):
                    weights_iter = get_weight_iter(
                        self.vllm_model_config, rollback_loader
                    )
                    self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
400

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Args:
            master_address: host of the group's TCP rendezvous.
            master_port: port of the group's TCP rendezvous.
            rank_offset: group rank of this engine's TP rank 0. This process
                joins as ``rank_offset + tp_rank``.
            world_size: total number of processes in the group.
            group_name: name of the custom group; must be non-empty.
            backend: torch.distributed backend for the group.

        Returns:
            ``(success, message)``. Initialization errors are logged and
            reported through this tuple rather than raised.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            if str(backend).lower() == "nccl":
                # `device_ids` names a local device index. Use the device this
                # process is bound to (self.gpu_id), not its rank inside the
                # update group, which generally differs.
                dist.barrier(
                    group=self._model_update_group, device_ids=[self.gpu_id]
                )
            else:
                # NOTE(review): some torch versions reject `device_ids` for
                # non-NCCL backends, so it is only passed for NCCL.
                dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
        target_dtype = (
            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        )
        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

        assert (
            self._model_update_group is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

498
499
500
501
502
503
504
505
506
507
508
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for the adapters in ``server_args.lora_paths``."""
        server_args = self.server_args
        manager_kwargs = dict(
            base_model=self.model,
            lora_paths=server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        self.lora_manager = LoRAManager(**manager_kwargs)
        logger.info("LoRA manager ready.")

509
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens fit in the KV cache.

        ``cell_size`` is the number of bytes one token's cache entry occupies
        across all layers. The budget is the currently available memory minus
        the ``1 - mem_fraction_static`` share of ``total_gpu_memory``. Memory
        figures are converted with ``1 << 30``, i.e. they are treated as GiB.

        Args:
            total_gpu_memory: memory figure from init_torch_distributed, in the
                same unit as get_available_gpu_memory.

        Returns:
            Estimated token capacity. Not clamped, so it can be zero or negative
            when memory is short.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )

        cfg = self.model_config
        element_bytes = torch._utils._element_size(self.kv_cache_dtype)
        if cfg.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla:
            # MLA: kv_lora_rank + qk_rope_head_dim elements per token per layer.
            elements_per_layer = cfg.kv_lora_rank + cfg.qk_rope_head_dim
        else:
            # MHA: a K and a V vector for every KV head on this TP rank.
            elements_per_layer = cfg.get_num_kv_heads(self.tp_size) * cfg.head_dim * 2
        cell_size = elements_per_layer * cfg.num_hidden_layers * element_bytes

        reserved = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved
        return int(rest_memory * (1 << 30) // cell_size)

536
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Allocate the request-to-token pool and the token-to-KV cache pool.

        Args:
            total_gpu_memory: memory figure from init_torch_distributed,
                forwarded to profile_max_num_token.
            max_num_reqs: request-pool capacity. When None it is derived from
                the token budget and clamped to [2048, 4096].
            max_total_tokens: optional user cap on KV cache tokens. It can only
                lower the profiled capacity, never raise it.

        Raises:
            ValueError: for an unsupported ``server_args.kv_cache_dtype``.
            RuntimeError: if the resulting KV capacity is not positive.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Module logger (previously the root `logging` module), so the
                # warning goes through this module's logger like other messages.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
621

Lianmin Zheng's avatar
Lianmin Zheng committed
622
623
624
625
626
627
628
629
630
    def init_cublas(self):
        """Run a tiny float16 matmul on CUDA so cuBLAS is initialized early.

        Skipping this warm-up can lead to errors being raised later.

        Returns:
            The 16x16 product tensor (otherwise unused).
        """
        shape = (16, 16)
        options = {"dtype": torch.float16, "device": "cuda"}
        lhs = torch.ones(shape, **options)
        rhs = torch.ones(shape, **options)
        return lhs @ rhs

631
632
633
634
635
636
637
638
    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
639
            )
640
            assert not self.model_config.is_encoder_decoder, (
641
642
643
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
Shuo Yang's avatar
Shuo Yang committed
644
645
646
647
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
648
649
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
650
        else:
651
652
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
653
            )
654

Shuo Yang's avatar
Shuo Yang committed
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load per-layer heavy-channel indices for double sparsity.

        Reads the JSON file at ``self.server_args.ds_channel_config_path``.
        For each of ``self.model_config.num_hidden_layers`` layers it takes the
        entry keyed ``model.layers.<i>.self_attn.<selected_channel>_proj`` and
        keeps its first ``self.server_args.ds_heavy_channel_num`` columns.
        The result is made contiguous, moved to CUDA and appended to
        ``self.sorted_channels``.

        Args:
            selected_channel: projection name prefix; ``"_proj"`` is appended
                to it (presumably something like ``"k"`` or ``"q"`` — confirm
                against callers).
        """
        suffix = "." + selected_channel + "_proj"
        # Reset before touching the file, so a failed load leaves an empty list.
        self.sorted_channels = []

        with open(self.server_args.ds_channel_config_path, "r") as config_file:
            channel_config = json.load(config_file)

        for layer_id in range(self.model_config.num_hidden_layers):
            layer_key = f"model.layers.{layer_id}.self_attn{suffix}"
            channels = torch.tensor(channel_config[layer_key])
            # Keep only the leading "heavy" columns of every row.
            heavy = channels[:, : self.server_args.ds_heavy_channel_num]
            self.sorted_channels.append(heavy.contiguous().cuda())

673
    def init_cuda_graphs(self):
        """Capture cuda graphs into ``self.cuda_graph_runner`` when applicable.

        The attribute is reset to ``None`` first.  It stays ``None`` for
        non-generation models and when ``server_args.disable_cuda_graph`` is
        set; otherwise a ``CudaGraphRunner`` is built for this runner.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        skip_capture = (not self.is_generation) or self.server_args.disable_cuda_graph
        if skip_capture:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
688

689
690
691
692
693
694
695
    def apply_torch_tp(self):
        """Shard ``self.model`` using torch-native tensor parallelism.

        Builds a one-dimensional device mesh of ``self.tp_size`` entries on
        ``self.device`` and passes it to sglang's ``tensor_parallel`` helper.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

696
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run one decode step for ``forward_batch``.

        Replays a captured CUDA graph when the graph runner accepts this batch.
        Otherwise it sets ``positions`` to ``seq_lens - 1`` as int64, prepares
        attention metadata and calls the model eagerly.
        """
        graph_runner = self.cuda_graph_runner
        if graph_runner and graph_runner.can_run(forward_batch):
            return graph_runner.replay(forward_batch)

        # Positions are the last index of each sequence (seq_len - 1).
        last_index = forward_batch.seq_lens - 1
        forward_batch.positions = last_index.to(torch.int64)
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

706
    def forward_extend(self, forward_batch: ForwardBatch):
        """Run an extend (prefill) forward pass for ``forward_batch``.

        Attention metadata is prepared first.  Non-generation (embedding)
        models are called with ``get_embedding=True``.  Generation models
        additionally receive ``input_embeds`` when the batch carries them.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        model_args = (forward_batch.input_ids, forward_batch.positions, forward_batch)

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(*model_args, get_embedding=True)

        if forward_batch.input_embeds is None:
            return self.model.forward(*model_args)

        # NOTE(review): the cast to bfloat16 is hard-coded rather than derived
        # from the model dtype — confirm this is intended.
        return self.model.forward(
            *model_args, input_embeds=forward_batch.input_embeds.bfloat16()
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
728

Ke Bao's avatar
Ke Bao committed
729
    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the forward pass for an idle batch.

        Uses CUDA-graph replay when the graph runner accepts the batch and
        calls the model directly otherwise.  This path does not call
        ``attn_backend.init_forward_metadata``.
        """
        runner = self.cuda_graph_runner
        if not (runner and runner.can_run(forward_batch)):
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        return runner.replay(forward_batch)

737
738
739
740
741
    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch ``forward_batch`` to the handler for its forward mode.

        Decode batches go to ``forward_decode``, extend batches to
        ``forward_extend`` and idle batches to ``forward_idle``.

        Raises:
            ValueError: the batch's forward mode is none of the above.
        """
        mode = forward_batch.forward_mode
        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(forward_batch)
        if mode.is_idle():
            return self.forward_idle(forward_batch)
        # Fixed typo in the message ("Invaid" -> "Invalid").
        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
746

747
748
749
750
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Sample next token ids from ``logits_output.next_token_logits``.

        If ``sampling_info.sampling_info_done`` is unset (normal mode), the
        regex vocab mask and penalties are updated here.  Otherwise (overlap
        mode) that work already happened elsewhere; when grammars are active,
        this waits on ``sampling_info_done``.  Biases, penalties and masks are
        then applied via ``apply_logits_bias`` and ``self.sampler`` picks the
        tokens.
        """
        info = forward_batch.sampling_info

        if not info.sampling_info_done:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            info.update_regex_vocab_mask()
            info.update_penalties()
        elif info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            info.sampling_info_done.wait()

        biased_logits = self.apply_logits_bias(logits_output.next_token_logits, info)

        # Sample the next tokens.
        return self.sampler(biased_logits, info)

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        """Apply the batch's logit adjustments and return the adjusted logits.

        The steps run in this order, and each is skipped when its tensor is None:
          1. additive ``logit_bias``;
          2. additive ``linear_penalties`` (min-token, presence, frequency);
          3. repetition ``scaling_penalties``: positive logits are divided by
             the penalty, non-positive logits are multiplied by it;
          4. the regex ``vocab_mask`` via ``sampling_info.apply_mask``.

        Steps 1-2 modify ``logits`` in place.  Step 3 creates a new tensor, so
        callers must use the return value.
        """
        # Apply logit_bias, then min-token / presence / frequency penalties.
        for additive in (sampling_info.logit_bias, sampling_info.linear_penalties):
            if additive is not None:
                logits.add_(additive)

        # repetition
        scaling = sampling_info.scaling_penalties
        if scaling is not None:
            logits = torch.where(logits > 0, logits / scaling, logits * scaling)

        # Apply regex vocab_mask (apply_mask's return value is discarded, so it
        # presumably works in place — confirm).
        mask = sampling_info.vocab_mask
        if mask is not None:
            sampling_info.apply_mask(logits=logits, vocab_mask=mask)

        return logits

Yineng Zhang's avatar
Yineng Zhang committed
789
790
791
792
793
794
795
796
797
    @property
    def model_is_mrope(self) -> bool:
        """Whether the HF config's ``rope_scaling`` has type ``"mrope"``.

        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases.  A missing or ``None`` ``rope_scaling`` counts as not mrope.
        """
        scaling_cfg = getattr(self.model_config.hf_config, "rope_scaling", {})
        return scaling_cfg is not None and scaling_cfg.get("type", None) == "mrope"

798
799
800
801
802
803
804
805

@lru_cache()
def import_model_classes():
    """Discover sglang's model implementations, keyed by class name.

    Imports every non-package module under ``sglang.srt.models`` and collects
    its ``EntryClass`` attribute.  The attribute may be a single class or a
    list of classes, so one module can provide several architectures.  A
    module that fails to import is logged and skipped, unless
    ``crash_on_warnings()`` is true.  The result is memoized by ``lru_cache``.

    Returns:
        dict mapping each class's ``__name__`` to the class.

    Raises:
        ValueError: a module failed to import and ``crash_on_warnings()`` is
            true.  The import error is chained as ``__cause__``.
        AssertionError: two entry classes share the same ``__name__``.
    """
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue

        try:
            module = importlib.import_module(name)
        except Exception as e:
            # Best effort: a broken model module should not block the others.
            logger.warning(f"Ignore import error when loading {name}. {e}")
            if crash_on_warnings():
                # Chain the import error so its traceback is not lost.
                raise ValueError(f"Ignore import error when loading {name}. {e}") from e
            continue

        if not hasattr(module, "EntryClass"):
            continue

        entry = module.EntryClass
        # To support multiple model classes in one module
        entries = entry if isinstance(entry, list) else [entry]
        for model_cls in entries:
            assert (
                model_cls.__name__ not in model_arch_name_to_cls
            ), f"Duplicated model implementation for {model_cls.__name__}"
            model_arch_name_to_cls[model_cls.__name__] = model_cls

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Return sglang's implementation class for architecture ``model_arch``.

    The lookup uses the classes found by ``import_model_classes()``.  The
    module-level monkey patch installs this function as
    ``ModelRegistry._try_load_model_cls``.

    Raises:
        ValueError: ``model_arch`` is unknown.  The message lists the
            supported architecture names.
    """
    registry = import_model_classes()
    if model_arch in registry:
        return registry[model_arch]
    raise ValueError(
        f"Unsupported architectures: {model_arch}. "
        f"Supported list: {list(registry.keys())}"
    )


# Monkey patch model loader: resolve architectures with sglang's own model
# classes, and make each of the predicates below always return False.
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
for _predicate_name in (
    "is_multimodal_model",
    "is_attention_free_model",
    "model_has_inner_state",
    "is_embedding_model",
):
    setattr(ModelRegistry, _predicate_name, lambda model_architectures: False)
# Do not leave the loop variable behind as a module attribute.
del _predicate_name