model_runner.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
            }
        )

        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

        if not server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        self.tp_group = get_tp_group()
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )

        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        # Load the model and create memory pool
        self.load_model()
        self.init_memory_pool(
            total_gpu_memory,
            server_args.max_num_reqs,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flashinfer()

        if self.is_generation:
            # FIXME: Currently, cuda graphs only capture decode steps, which exist
            # only in causal models.
            # Capture cuda graphs
            self.init_cuda_graphs()

    def load_model(self):
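        """Load model weights via vLLM's loader and record the model's dtype and generation capability."""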
        logger.info(
            f"[gpu={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        if (
            self.server_args.efficient_weight_load
            and "llama" in self.server_args.model_path.lower()
            and self.server_args.quantization == "fp8"
        ):
            from sglang.srt.model_loader.model_loader import get_model
        else:
            from vllm.model_executor.model_loader import get_model

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures
        )

        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
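        """Estimate the maximum number of KV cache tokens that fit in the remaining memory budget.

        The per-token cell size depends on whether the MLA or the standard MHA KV layout is used.
        """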
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
    ):
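        """Profile the token capacity, then allocate the request-to-token and token-to-KV-cache pools."""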
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                warnings.warn(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("using MLA Triton implementaion, flashinfer is disabled")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.disable_flashinfer = True
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"[gpu={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_flashinfer(self):
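        """Create the flashinfer prefill (ragged and paged) and decode wrappers, unless flashinfer is disabled."""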
        if self.server_args.disable_flashinfer:
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

        self.flashinfer_workspace_buffers = torch.empty(
            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
        )
        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0], "NHD"
        )
        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[1], "NHD"
        )
        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0],
            "NHD",
            use_tensor_cores=use_tensor_cores,
        )

    def init_cuda_graphs(self):
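        """Capture CUDA graphs for a fixed list of decode batch sizes, unless cuda graph or flashinfer is disabled."""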
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info(
            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
        )
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable torch compile by not using --enable-torch-compile\n"
                "2. disable cuda graph by --disable-cuda-graph\n"
                "3. set --mem-fraction-static to a smaller value\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
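        """Run a decode step, replaying a captured CUDA graph when the batch size allows it."""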
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(
            self, batch, ForwardMode.DECODE
        )

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
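        """Run a prefill (extend) forward pass."""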
        input_metadata = InputMetadata.from_schedule_batch(
            self, batch, forward_mode=ForwardMode.EXTEND
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
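        """Run a prefill (extend) forward pass with the extra image inputs required by multimodal models."""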
        input_metadata = InputMetadata.from_schedule_batch(
            self, batch, forward_mode=ForwardMode.EXTEND
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
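        """Dispatch to the multimodal, decode, or extend forward path based on the forward mode."""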
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


@lru_cache()
def import_model_classes():
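    """Scan the sglang.srt.models package and map architecture names to their model classes."""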
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models, such as chatglm, have an incorrect class name set in config.json
            # usage: [("From_Entry_Class_Name", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        assert remap[0] not in model_arch_name_to_cls
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
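    """Look up the model class registered for the given architecture name."""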
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)