"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import (
    Batch,
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_llama3_405b_fp8,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
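        # Publish selected server args through a process-global dict so model and
        # attention-layer code can read these flags without passing ServerArgs around.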
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
            }
        )

        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

        if not server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        self.tp_group = get_tp_group()
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )

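        # Sanity-check that free GPU memory is roughly balanced across tensor-parallel ranks.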
        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        # Load the model and create memory pool
        self.load_model()
        self.init_memory_pool(
            total_gpu_memory,
            server_args.max_num_reqs,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flash_infer()

        # Capture cuda graphs
        self.init_cuda_graphs()

    def load_model(self):
        logger.info(
            f"[gpu={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

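        # Pick the weight loader: SGLang's own loader handles fp8 Llama checkpoints when
        # --efficient-weight-load is set; otherwise fall back to vLLM's default loader.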
        if (
            self.server_args.efficient_weight_load
            and "llama" in self.server_args.model_path.lower()
            and self.server_args.quantization == "fp8"
        ):
            from sglang.srt.model_loader.model_loader import get_model
        else:
            from vllm.model_executor.model_loader import get_model

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
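        # cell_size is the KV-cache footprint of a single token in bytes: MLA caches a
        # compressed latent (kv_lora_rank + qk_rope_head_dim) per layer, while MHA caches
        # full K and V vectors for every local KV head per layer.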
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
    ):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                warnings.warn(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Using the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

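        # Heuristic default: scale the number of request slots with the token budget
        # relative to the context length, clamped to the range [2048, 5120].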
        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
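        # Allocate the token-to-KV pool. MLA models store a compressed latent per token;
        # regular MHA models store full per-head K/V entries.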
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("using MLA Triton implementaion, flashinfer is disabled")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.disable_flashinfer = True
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"[gpu={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_flash_infer(self):
        if self.server_args.disable_flashinfer:
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

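        # flashinfer's non-tensor-core decode kernels are only compiled for certain
        # query/KV head group sizes; fall back to the tensor-core path otherwise.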
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

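        # Allocate two workspace buffers and build the flashinfer wrappers: ragged and
        # paged prefill, plus a paged decode wrapper that shares the first buffer.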
        self.flashinfer_workspace_buffers = torch.empty(
            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
        )
        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0], "NHD"
        )
        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[1], "NHD"
        )
        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0],
            "NHD",
            use_tensor_cores=use_tensor_cores,
        )

    def init_cuda_graphs(self):
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info(
            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
        )
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable torch compile by not using --enable-torch-compile\n"
                "2. disable cuda graph by --disable-cuda-graph\n"
                "3. set --mem-fraction-static to a smaller value\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

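    # Dispatch a batch to the matching forward path based on modality and forward mode.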
    def forward(self, batch: Batch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(entry, list):
                    # Support multiple model classes in one module
                    for tmp in entry:
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models (e.g. chatglm) set an incorrect class name in config.json
            # usage: EntryClassRemapping = [("ArchitectureNameInConfig", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)