model_runner.py 29.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""ModelRunner runs the forward passes of the models."""
15

16
import gc
Shuo Yang's avatar
Shuo Yang committed
17
import json
18
import logging
19
import time
20
from typing import List, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22

import torch
23
import torch.distributed as dist
zhyncs's avatar
zhyncs committed
24
25
26
27
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
28
    set_custom_all_reduce,
zhyncs's avatar
zhyncs committed
29
)
Lianmin Zheng's avatar
Lianmin Zheng committed
30

31
32
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
33
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
Shuo Yang's avatar
Shuo Yang committed
34
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
35
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
36
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
37
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Liangsheng Yin's avatar
Liangsheng Yin committed
38
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
39
from sglang.srt.layers.sampler import Sampler
40
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
41
from sglang.srt.lora.lora_manager import LoRAManager
42
from sglang.srt.managers.schedule_batch import global_server_args_dict
43
from sglang.srt.mem_cache.memory_pool import (
Shuo Yang's avatar
Shuo Yang committed
44
    DoubleSparseTokenToKVPool,
45
46
47
48
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
49
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
50
from sglang.srt.model_loader import get_model
Lianmin Zheng's avatar
Lianmin Zheng committed
51
from sglang.srt.server_args import ServerArgs
52
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
53
from sglang.srt.utils import (
54
    enable_show_time_cost,
55
    get_available_gpu_memory,
56
    init_custom_process_group,
HAI's avatar
HAI committed
57
    is_hip,
58
    monkey_patch_vllm_gguf_config,
59
    monkey_patch_vllm_p2p_access_check,
60
    set_cpu_offload_max_bytes,
61
)
62

Ying Sheng's avatar
Ying Sheng committed
63
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
64

Lianmin Zheng's avatar
Lianmin Zheng committed
65
66

class ModelRunner:
67
68
    """ModelRunner runs the forward passes of the models."""

Lianmin Zheng's avatar
Lianmin Zheng committed
69
70
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
    ):
        """Set up everything needed to run forward passes for one TP rank.

        The steps run in a fixed order: adjust ``server_args`` for the model
        type, publish selected args into ``global_server_args_dict``, init
        torch distributed, load the weights, apply torchao and (optionally)
        torch TP, then create the LoRA manager, the memory pools, the
        attention backend and, on CUDA, the CUDA graphs.

        Args:
            model_config: Model configuration; its ``dtype`` may be changed
                while loading the model.
            mem_fraction_static: Fraction of device memory budgeted by
                ``profile_max_num_token``; reduced by 5% for multimodal models.
            gpu_id: Device index passed to ``set_device``.
            tp_rank: Rank of this worker within the tensor-parallel group.
            tp_size: Tensor-parallel world size.
            nccl_port: Port for the ``tcp://127.0.0.1`` init method, used when
                ``server_args.dist_init_addr`` is not set.
            server_args: Server arguments. Mutated in place here (e.g.
                ``attention_backend``, ``disable_cuda_graph``,
                ``chunked_prefill_size``, ``disable_radix_cache``).
            is_draft_worker: True for a speculative-decoding draft model; it
                skips creating the distributed environment and sizes its KV
                pool from ``server_args.draft_runner_cache_size``.

        Raises:
            ValueError: if double sparsity is enabled without
                ``ds_heavy_channel_type``. Other errors come from the init
                helpers called here.
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        # True only on TP rank 0; presumably used to limit logging to one rank.
        self.should_log = tp_rank == 0
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Model-specific adjustment
        # MLA models (unless server_args.disable_mla) are forced onto triton.
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        # Double sparsity needs the triton backend, no CUDA graphs, and the
        # per-layer channel config loaded before the memory pool is built.
        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal:
            # Leave extra headroom for multimodal models.
            self.mem_fraction_static *= 0.95
            logger.info(
                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                f"because this is a multimodal model."
            )

            if self.model_config.hf_config.architectures == [
                "MllamaForConditionalGeneration"
            ]:
                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                server_args.chunked_prefill_size = -1

            if self.model_config.hf_config.architectures == [
                "Qwen2VLForConditionalGeneration"
            ]:
                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                logger.info(
                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                )
                server_args.chunked_prefill_size = -1
                server_args.disable_radix_cache = True

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_outlines_disk_cache:
            # Imported lazily so outlines is only touched when this flag is set.
            from outlines.caching import disable_cache

            disable_cache()

        # Publish selected args into the process-wide dict from schedule_batch;
        # presumably read by components without direct access to server_args.
        # Note this runs after the backend overrides above.
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
            }
        )

        # cpu_offload_gb is given in GiB; the helper takes bytes.
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Apply torchao quantization
        apply_torchao_config_to_model(
            self.model, global_server_args_dict["torchao_config"]
        )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
            self.torch_tp_applied = True
        else:
            self.torch_tp_applied = False

        # Init memory pool and attention backends
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        # The memory measured before model loading is the total used for the
        # static memory budget.
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        # Only CUDA gets the cuBLAS warm-up and CUDA graph capture.
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()
200
201

    def init_torch_distributed(self):
        """Initialize torch.distributed and the tensor-parallel group.

        Sets the current device to ``self.gpu_id``, picks the collective backend
        from ``self.device`` and, unless this is a draft worker, creates the
        distributed environment and model-parallel groups.

        Returns:
            The available device memory in GB as reported by
            ``get_available_gpu_memory`` (taken across TP ranks when
            ``tp_size > 1``).

        Raises:
            ValueError: if ``self.device`` has no known collective backend (only
                raised when the backend is actually needed, i.e. for a non-draft
                worker), or if available memory is unbalanced across TP ranks.
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        torch.get_device_module(self.device).set_device(self.gpu_id)

        # Collective backend per device type.
        # TODO(liangan1): xpu uses gloo just to bypass an initialization failure;
        # it should use xccl in the future.
        backend = {"cuda": "nccl", "xpu": "gloo", "hpu": "hccl"}.get(self.device)
        if backend is None and not self.is_draft_worker:
            # Previously this surfaced as an UnboundLocalError deep inside
            # init_distributed_environment; fail early with a clear message.
            raise ValueError(
                f"Unsupported device for torch distributed: {self.device!r}. "
                f"Expected one of 'cuda', 'xpu', 'hpu'."
            )

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model
            # worker; the draft worker reuses it.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
                rank=self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
            )
            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Check memory for tensor parallelism: if the cross-rank minimum is well
        # below this rank's free memory, another process is likely using a GPU.
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
247

Lianmin Zheng's avatar
Lianmin Zheng committed
248
    def load_model(self):
        """Load the model weights onto ``self.device``.

        On CUDA, devices below sm75 are rejected, and devices below sm80 are
        switched to float16 because they lack bfloat16 support. Sets
        ``self.load_config``, ``self.model``, ``self.sliding_window_size`` and
        ``self.dtype``.

        Raises:
            RuntimeError: if the CUDA device's compute capability is below sm75.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            # Compare (major, minor) as a tuple. Checking only the minor version
            # when major < 8 wrongly accepts e.g. sm35/sm37 (minor >= 5).
            capability = torch.cuda.get_device_capability()
            if capability < (7, 5):
                raise RuntimeError("SGLang only supports sm75 and above.")
            if capability[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
        )

        # Parse other args
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.dtype = self.model_config.dtype

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
294

295
296
297
298
    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in-place from the disk.

        Args:
            model_path: Checkpoint location to load the new weights from.
            load_format: Weight format; only formats served by vllm's
                ``DefaultModelLoader`` are supported.

        Returns:
            ``(success, message)``. On success, ``model_config.model_path``,
            ``server_args.model_path``/``load_format`` and ``self.load_config``
            are switched to the new checkpoint. On failure they are left
            untouched. If loading fails part-way, the original weights are
            reloaded from the previous ``model_path`` before ``False`` is
            returned. An exception during that rollback propagates.
        """
        from sglang.srt.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(weight_loader, path):
            # The path is passed explicitly so `model_config.model_path` only
            # changes after a successful load (rollback needs the old path).
            return weight_loader._get_weights_iterator(
                DefaultModelLoader.Source(
                    path,
                    revision=self.model_config.revision,
                    fall_back_to_pt=getattr(
                        self.model, "fall_back_to_pt_during_load", True
                    ),
                )
            )

        def model_load_weights(model, weights_iter):
            model.load_weights(weights_iter)
            # Quantized modules may need post-processing once weights are in.
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(loader, model_path)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                # Reload from the checkpoint the current weights came from (the
                # new path is what just failed), preferring its original loader.
                rollback_loader = get_model_loader(self.load_config)
                if not isinstance(rollback_loader, DefaultModelLoader):
                    rollback_loader = loader
                weights_iter = get_weight_iter(
                    rollback_loader, self.model_config.model_path
                )
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.model_config.model_path = model_path
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.

        Args:
            master_address: Host of the TCP rendezvous.
            master_port: Port of the TCP rendezvous.
            rank_offset: Added to ``self.tp_rank`` to get this process's rank in
                the update group.
            world_size: Size of the update group.
            group_name: Non-empty name of the group.
            backend: Collective backend for the group.

        Returns:
            ``(success, message)``; failures are logged, not raised.

        Raises:
            AssertionError: if the default process group is not initialized or
                ``group_name`` is empty.
        """
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            # `device_ids` selects a *local* device, so it must be this
            # process's GPU index rather than the global group rank. It is only
            # meaningful for NCCL; some torch versions reject it for other
            # backends.
            if backend == "nccl":
                dist.barrier(group=self._model_update_group, device_ids=[self.gpu_id])
            else:
                dist.barrier(group=self._model_update_group)
            return True, "Succeeded to initialize custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, name, dtype, shape):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.

        Args:
            name: the name of the parameter to be updated.
            dtype: the data type of the parameter to be updated, either a
                ``torch.dtype`` or its name (``"bfloat16"`` or
                ``"torch.bfloat16"``).
            shape: the shape of the parameter to be updated.

        Returns:
            ``(success, message)``. Broadcast/load failures are logged and
            reported, not raised.

        Raises:
            AssertionError: if ``init_weights_update_group`` has not set up the
                update group.
        """
        if isinstance(dtype, torch.dtype):
            target_dtype = dtype
        else:
            # Accept both "bfloat16" and the str() form "torch.bfloat16".
            target_dtype = getattr(torch, dtype.removeprefix("torch."))

        # getattr: the attribute only exists once init_weights_update_group ran,
        # so a plain attribute access would raise AttributeError instead of the
        # intended assertion.
        assert (
            getattr(self, "_model_update_group", None) is not None
        ), "model update group must be initialized"

        try:
            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
            # Rank 0 of the update group (the trainer) is the source.
            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
            self.model.load_weights([(name, weights)])
            return True, f"Succeeded to update parameter {name} online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameter online: {e}. "
                f"The full weights of the ModelRunner are partially updated. "
                f"Please discard the whole weights."
            )
            logger.error(error_msg)
            return False, error_msg

447
448
449
    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs directly into the model in place.

        Any exception from the model's ``load_weights`` propagates to the caller.

        Returns:
            ``(True, "Success")`` after loading completes.
        """
        load_weights = self.model.load_weights
        load_weights(named_tensors)
        outcome = (True, "Success")
        return outcome
450

451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Fetch a parameter's weights by name, like Hugging Face ``get_parameter``.

        Delegates to the model's own ``get_weights_by_name`` with
        ``truncate_size`` and this runner's ``tp_size``. Any exception is logged
        and ``None`` is returned instead.

        Only used for unit tests; it is not optimized. For bulk access use
        ``torch.save`` / ``torch.load``.
        """
        # TODO: (chenyang) Add support for Qwen models.
        fetched = None
        try:
            fetched = self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
        return fetched

468
469
470
471
472
473
474
475
476
477
478
    def init_lora_manager(self):
        """Create ``self.lora_manager`` for ``server_args.lora_paths`` on top of the loaded model.

        Uses the runner's load config and dtype, so it must run after
        ``load_model``.
        """
        args = self.server_args
        manager = LoRAManager(
            base_model=self.model,
            lora_paths=args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        self.lora_manager = manager
        logger.info("LoRA manager ready.")

479
    def profile_max_num_token(self, total_gpu_memory: int):
        """Estimate how many tokens of KV cache fit in the static memory budget.

        The budget is the currently available device memory minus the share
        ``1 - mem_fraction_static`` of ``total_gpu_memory`` that is kept free.
        Memory values are in the GB units reported by
        ``get_available_gpu_memory`` and are converted to bytes with ``1 << 30``.

        Args:
            total_gpu_memory: Device memory (GB) measured before model loading.

        Returns:
            The token count; zero or negative when the budget is exhausted.
            Requires ``self.kv_cache_dtype`` to be set.
        """
        cfg = self.model_config
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        layers = cfg.num_hidden_layers
        element_size = torch._utils._element_size(self.kv_cache_dtype)

        uses_mla = (
            cfg.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        )
        if uses_mla:
            # One vector of kv_lora_rank + qk_rope_head_dim values per layer.
            cell_size = (
                (cfg.kv_lora_rank + cfg.qk_rope_head_dim) * layers * element_size
            )
        else:
            # K and V (hence * 2) for each of this rank's KV heads per layer.
            cell_size = (
                cfg.get_num_kv_heads(self.tp_size)
                * cfg.head_dim
                * layers
                * 2
                * element_size
            )

        reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
        rest_memory = available_gpu_memory - reserved_memory
        return int(rest_memory * (1 << 30) // cell_size)

506
    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        """Size and allocate the request-to-token pool and the KV-cache pool.

        Args:
            total_gpu_memory: Device memory (GB) measured before model loading,
                forwarded to ``profile_max_num_token``.
            max_num_reqs: Maximum concurrent requests. When ``None`` it is
                derived from the token budget and clamped to ``[2048, 4096]``.
            max_total_tokens: Optional user cap on KV-cache tokens. It can only
                lower the profiled value; a larger value logs a warning.

        Raises:
            ValueError: for an unsupported ``server_args.kv_cache_dtype``.
            RuntimeError: if the resulting token budget is not positive.
        """
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                # The draft worker uses the size published by the target worker.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
            else:
                # Publish a draft-cache size: the target's tokens plus room for
                # each request's speculative steps and a small margin.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    + max_num_reqs * self.server_args.speculative_num_steps
                    + 100
                )

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                # Use the module logger (not the root `logging` module) so this
                # warning is attributed and filtered like the rest of the file.
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        # NOTE(review): the +1 slot and +4 context padding are kept as-is;
        # their exact purpose is not visible from this file.
        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
602

Lianmin Zheng's avatar
Lianmin Zheng committed
603
604
605
606
607
608
609
610
611
    def init_cublas(self):
        """Warm up cuBLAS with a tiny float16 matmul on CUDA.

        Without this initialization, later cuBLAS use may raise errors.

        Returns:
            The 16x16 product of two all-ones matrices.
        """
        lhs, rhs = (
            torch.ones((16, 16), dtype=torch.float16, device="cuda")
            for _ in range(2)
        )
        return lhs @ rhs

612
613
614
615
616
617
618
619
    def init_attention_backend(self):
        """Init attention kernel backend.

        Instantiates the backend named by ``server_args.attention_backend``
        into ``self.attn_backend``. The triton choice uses the double-sparsity
        variant when ``enable_double_sparsity`` is set.

        Raises:
            AssertionError: triton with sliding-window attention or an
                encoder-decoder model.
            ValueError: unknown backend name.
        """
        backend_name = self.server_args.attention_backend

        if backend_name == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
            return

        if backend_name == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
            return

        if backend_name != "triton":
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

        # Triton path: reject features the triton kernels cannot handle.
        assert self.sliding_window_size is None, (
            "Window attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )
        assert not self.model_config.is_encoder_decoder, (
            "Cross attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )
        backend_cls = (
            DoubleSparseAttnBackend
            if self.server_args.enable_double_sparsity
            else TritonAttnBackend
        )
        self.attn_backend = backend_cls(self)
635

Shuo Yang's avatar
Shuo Yang committed
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
    def init_double_sparsity_channel_config(self, selected_channel):
        """Load the per-layer heavy-channel config used by double sparsity.

        Reads the JSON file at ``server_args.ds_channel_config_path``. For each
        layer ``i`` it takes entry ``model.layers.{i}.self_attn.{selected_channel}_proj``
        and keeps the first ``ds_heavy_channel_num`` columns (the entry must be
        at least 2-D). The contiguous result for each layer is appended to
        ``self.sorted_channels``.

        Args:
            selected_channel: Projection-name prefix, combined into
                ``.{selected_channel}_proj``.

        Raises:
            KeyError: if the config has no entry for some layer.
        """
        selected_channel = f".{selected_channel}_proj"
        # This runs from __init__ before set_device(gpu_id), so a bare
        # `.cuda()` would place every rank's channels on device 0. Target this
        # rank's device explicitly instead.
        target_device = torch.device(self.device, self.gpu_id)
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r", encoding="utf-8") as f:
            channel_config = json.load(f)

        heavy_channel_num = self.server_args.ds_heavy_channel_num
        for i in range(self.model_config.num_hidden_layers):
            key = f"model.layers.{i}.self_attn{selected_channel}"
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[:, :heavy_channel_num]
                .contiguous()
                .to(target_device)
            )

653
    def init_cuda_graphs(self):
        """Capture cuda graphs.

        Leaves ``self.cuda_graph_runner`` as ``None`` for non-generation models
        or when ``server_args.disable_cuda_graph`` is set; otherwise builds a
        ``CudaGraphRunner`` and logs how long capture took.
        """
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        if not self.is_generation or self.server_args.disable_cuda_graph:
            return

        started_at = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
        elapsed = time.time() - started_at
        logger.info(f"Capture cuda graph end. Time elapsed: {elapsed:.2f} s")
670

671
672
673
674
675
676
677
    def apply_torch_tp(self):
        """Apply torch tensor parallelism to ``self.model``.

        Builds a one-dimensional device mesh of size ``self.tp_size`` on
        ``self.device``. The mesh and the model are then passed to
        ``sglang.srt.model_parallel.tensor_parallel``.
        """
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        mesh_shape = (self.tp_size,)
        mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
        tensor_parallel(self.model, mesh)

678
    def forward_decode(self, forward_batch: ForwardBatch):
        """Run a decode forward pass.

        Attention metadata is initialized for the batch before the model is
        called with the batch's input ids, positions and the batch itself.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        model_inputs = (forward_batch.input_ids, forward_batch.positions)
        return self.model.forward(*model_inputs, forward_batch)

684
    def forward_extend(self, forward_batch: ForwardBatch):
        """Run an extend forward pass.

        Attention metadata is initialized first. Then:
        - non-generation models are called with ``get_embedding=True``;
        - generation models are called with ``input_embeds`` (cast to
          bfloat16) when the batch carries them, otherwise with ids only.
        """
        self.attn_backend.init_forward_metadata(forward_batch)
        positional = (forward_batch.input_ids, forward_batch.positions, forward_batch)

        if not self.is_generation:
            # Only embedding models have get_embedding parameter
            return self.model.forward(*positional, get_embedding=True)

        embeds = forward_batch.input_embeds
        if embeds is None:
            return self.model.forward(*positional)

        # NOTE(review): the cast is hard-coded to bfloat16 regardless of the
        # model's dtype — confirm this matches the model's activation dtype.
        return self.model.forward(*positional, input_embeds=embeds.bfloat16())
Lianmin Zheng's avatar
Lianmin Zheng committed
706

Ke Bao's avatar
Ke Bao committed
707
708
709
710
711
    def forward_idle(self, forward_batch: ForwardBatch):
        """Run the model on an idle batch.

        Unlike the decode/extend paths, no attention metadata is initialized.
        """
        ids, pos = forward_batch.input_ids, forward_batch.positions
        return self.model.forward(ids, pos, forward_batch)

712
    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        """Dispatch a batch to the matching forward path.

        A captured cuda graph is replayed when the mode allows it and the
        runner exists and accepts the batch. Otherwise the batch goes to the
        decode, extend or idle path according to ``forward_mode``.

        Raises:
            ValueError: if the forward mode matches none of those paths.
        """
        mode = forward_batch.forward_mode

        if mode.is_cuda_graph():
            graph_runner = self.cuda_graph_runner
            if graph_runner and graph_runner.can_run(forward_batch):
                return graph_runner.replay(forward_batch)

        if mode.is_decode():
            return self.forward_decode(forward_batch)
        if mode.is_extend():
            return self.forward_extend(forward_batch)
        if mode.is_idle():
            return self.forward_idle(forward_batch)

        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
728

729
730
731
    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Sample next token ids from ``logits_output``.

        Before sampling, the batch's sampling info is brought up to date:
        - If ``sampling_info_done`` is set (overlap mode), the vocab mask was
          already computed while processing the previous batch's result. We
          only ``wait()`` on it when grammars are in use.
        - Otherwise the regex vocab mask and penalties are updated here.
        Logit biases are then applied to ``next_token_logits``, and the
        sampler's output is returned.
        """
        info = forward_batch.sampling_info
        done_signal = info.sampling_info_done

        if not done_signal:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            info.update_regex_vocab_mask()
            info.update_penalties()
        elif info.grammars:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            done_signal.wait()

        # Apply logit bias (return value unused, presumably modifies the logits in place).
        info.apply_logits_bias(logits_output.next_token_logits)

        # Sample the next tokens
        return self.sampler(
            logits_output,
            info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
        )

Yineng Zhang's avatar
Yineng Zhang committed
754
755
756
757
758
759
760
761
    @property
    def model_is_mrope(self) -> bool:
        """Detect whether the model uses multimodal rotary embeddings (mrope).

        mrope requires keep "rope_deltas" between prompt and decoding phases.

        The model counts as mrope when ``hf_config.rope_scaling`` either declares
        ``"type": "mrope"`` or contains an ``"mrope_section"`` entry. The second
        check is needed because newer ``transformers`` Qwen2-VL configs rewrite
        ``"type": "mrope"`` to ``"default"`` but keep ``mrope_section``. With
        only the type check, such models would not be detected.

        Returns False when ``rope_scaling`` is missing, ``None`` or empty.
        """
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", None)
        if not rope_scaling:
            return False
        declared_mrope = rope_scaling.get("type", None) == "mrope"
        return declared_mrope or "mrope_section" in rope_scaling