"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)

        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
            }
        )

        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

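        # Unless the user explicitly requests a P2P access check, patch vLLM to
        # skip it.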
        if not server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
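        # Multi-node TP: true when not every rank in the TP group shares a node
        # with rank 0.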
        self.is_multi_node_tp = not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        )

        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        # Load the model and create memory pool
        self.load_model()
        self.init_memory_pool(
            total_gpu_memory,
            server_args.max_num_reqs,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flashinfer()

        if self.is_generation:
            # Capture cuda graphs.
            # FIXME: CUDA graphs currently capture only decode steps, which exist
            # only in causal models.
            self.init_cuda_graphs()

    def load_model(self):
        logger.info(
            f"[gpu={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"

        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_window_size()
            if hasattr(self.model, "get_window_size")
            else None
        )
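        # Models with sliding-window attention expose get_window_size();
        # None means full attention everywhere.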
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures
        )

        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.dtype)
            )
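        # cell_size = bytes one token occupies in the KV cache across all layers.
        # Worked example (hypothetical 8B-class config): 8 KV heads * head_dim 128
        # * 32 layers * 2 (K and V) * 2 bytes (fp16) = 131072 bytes = 128 KB/token.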
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
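        # rest_memory is measured in GiB; (1 << 30) converts it to bytes before
        # dividing by the per-token cell size.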
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
    ):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                warnings.warn(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
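            # Default heuristic: scale the request pool with how many
            # full-context sequences fit in the token budget (x512), then
            # clamp to the range [2048, 5120].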
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("Using MLA Triton implementation; flashinfer is disabled.")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.disable_flashinfer = True
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"[gpu={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """Run a small matmul to initialize cuBLAS; otherwise later calls may raise errors."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_flashinfer(self):
        if self.server_args.disable_flashinfer:
            assert (
                self.sliding_window_size is None
            ), "turn on flashinfer to support sliding-window attention"
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

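        # Use the tensor-core decode path when FlashInfer has no grouped-size
        # decode kernel compiled for this (num_qo_heads, num_kv_heads) pair.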
        use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        )

        if self.sliding_window_size is None:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = (
                BatchPrefillWithRaggedKVCacheWrapper(
                    self.flashinfer_workspace_buffer, "NHD"
                )
            )
            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer, "NHD"
            )
            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer,
                "NHD",
                use_tensor_cores=use_tensor_cores,
            )
        else:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = None
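            # Two wrapper sets for sliding-window models: presumably one for
            # sliding-window layers and one for full-attention layers.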
            self.flashinfer_prefill_wrapper_paged = []
            self.flashinfer_decode_wrapper = []
            for i in range(2):
                self.flashinfer_prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer, "NHD"
                    )
                )
                self.flashinfer_decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer,
                        "NHD",
                        use_tensor_cores=use_tensor_cores,
                    )
                )

    def init_cuda_graphs(self):
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info(
            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
        )
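        # Capture batch sizes 1, 2, 4, and every multiple of 8 up to 128.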
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable torch compile by not using --enable-torch-compile\n"
                "2. disable cuda graph by --disable-cuda-graph\n"
                "3. set --mem-fraction-static to a smaller value\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.DECODE,
        )

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


@lru_cache()
def import_model_classes():
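    """Discover model classes dynamically: import every module under
    sglang.srt.models and register the class(es) it exposes via `EntryClass`."""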
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models (e.g. chatglm) have an incorrect class name set
            # in config.json
            # usage: EntryClassRemapping = [("ClassNameFromConfig", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        assert remap[0] not in model_arch_name_to_cls
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
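# Replace vLLM's model-class resolver so that vLLM's get_model() resolves
# architectures to SGLang's model implementations.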
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)