"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
            }
        )

        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

        if not server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
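        # True when the tensor-parallel group spans more than one node.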
        self.is_multi_node_tp = not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        )

        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        # Load the model and create memory pool
        self.load_model()
        self.init_memory_pool(
            total_gpu_memory,
            server_args.max_num_reqs,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flashinfer()

        if self.is_generation:
            # FIXME: currently, CUDA graphs only capture decode steps, which only exist in causal (generation) models
            # Capture CUDA graphs
            self.init_cuda_graphs()

    def load_model(self):
        logger.info(
            f"[gpu={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_window_size()
            if hasattr(self.model, "get_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures
        )

        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
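        # Per-token KV cache size in bytes: MLA stores one compressed latent
        # (kv_lora_rank + qk_rope_head_dim) per layer; MHA stores K and V per KV head per layer.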
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
    ):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                warnings.warn(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("using MLA Triton implementaion, flashinfer is disabled")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.disable_flashinfer = True
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"[gpu={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_flashinfer(self):
        if self.server_args.disable_flashinfer:
            assert (
                self.sliding_window_size is None
            ), "turn on flashinfer to support window attention"
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

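        # Use the tensor core variant of the decode kernel when flashinfer has no
        # grouped-size decode kernel compiled for this query/KV head ratio.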
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

        if self.sliding_window_size is None:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = (
                BatchPrefillWithRaggedKVCacheWrapper(
                    self.flashinfer_workspace_buffer, "NHD"
                )
            )
            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer, "NHD"
            )
            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer,
                "NHD",
                use_tensor_cores=use_tensor_cores,
            )
        else:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = None
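            # Two wrapper sets are created so that sliding-window and full-attention
            # layers can be handled separately.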
            self.flashinfer_prefill_wrapper_paged = []
            self.flashinfer_decode_wrapper = []
            for i in range(2):
                self.flashinfer_prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer, "NHD"
                    )
                )
                self.flashinfer_decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer,
                        "NHD",
                        use_tensor_cores=use_tensor_cores,
                    )
                )

    def init_cuda_graphs(self):
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

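        # CUDA graphs are skipped when they are disabled or when flashinfer is disabled.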
        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info(
            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
        )
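        # Capture batch sizes 1, 2, 4, and every multiple of 8 up to 128.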
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable torch compile by not using --enable-torch-compile\n"
                "2. disable cuda graph by --disable-cuda-graph\n"
                "3. set --mem-fraction-static to a smaller value\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            ForwardMode.DECODE,
        )

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


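# Scan the sglang.srt.models package once and cache the mapping from
# architecture name to model class.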
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models such as chatglm have an incorrect class name set in config.json
            # usage: EntryClassRemapping = [("From_Entry_Class_Name", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        assert remap[0] not in model_arch_name_to_cls
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)