"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import (
    Batch,
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_llama3_405b_fp8,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)
logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
            }
        )

        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

        if not server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        self.tp_group = get_tp_group()
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )

        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        # Load the model and create memory pool
        self.load_model()
        self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
        self.init_cublas()
        self.init_flash_infer()

        # Capture cuda graphs
        self.init_cuda_graphs()

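    # Weight loading is delegated to vLLM's model loader; SGLang applies a few
    # monkey patches and hf_config overrides before calling get_model().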
    def load_model(self):
        logger.info(
            f"[gpu={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        if (
            self.server_args.efficient_weight_load
            and "llama" in self.server_args.model_path.lower()
            and self.server_args.quantization == "fp8"
        ):
            from sglang.srt.model_loader.model_loader import get_model
        else:
            from vllm.model_executor.model_loader import get_model

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        logger.info(
            f"[gpu={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

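    # Estimate how many KV-cache tokens fit in the memory budget: each token costs
    # head_num * head_dim * num_hidden_layers * 2 (K and V) * element_size bytes.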
    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.get_num_kv_heads(self.tp_size)
        cell_size = (
            head_num
            * head_dim
            * self.model_config.num_hidden_layers
            * 2
            * torch._utils._element_size(self.dtype)
        )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

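    # Allocate the request-to-token index pool and the token-to-KV-cache pool,
    # sized from profile_max_num_token() and the model's KV head/layer shapes.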
    def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_tokens,
            dtype=self.dtype,
            head_num=self.model_config.get_num_kv_heads(self.tp_size),
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
        logger.info(
            f"[gpu={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

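    # Set up FlashInfer prefill/decode wrappers over two shared workspace buffers.
    # Tensor-core decode is enabled when no grouped-size decode kernel is precompiled
    # for this head configuration.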
    def init_flash_infer(self):
        if self.server_args.disable_flashinfer:
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

        self.flashinfer_workspace_buffers = torch.empty(
            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
        )
        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0], "NHD"
        )
        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[1], "NHD"
        )
        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0],
            "NHD",
            use_tensor_cores=use_tensor_cores,
        )

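    # Capture CUDA graphs for a fixed list of batch sizes; skipped when
    # CUDA graphs or FlashInfer are disabled.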
    def init_cuda_graphs(self):
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info(
            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
        )
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                f"1. disable cuda graph by --disable-cuda-graph\n"
                f"2. set --mem-fraction-static to a smaller value\n"
                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
            )

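    # Decode replays a captured CUDA graph when the runner reports it can handle
    # the batch; otherwise it falls back to the eager forward pass.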
    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

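    # Extend (prefill) runs the model over the newly added input tokens of a batch.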
    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

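    # The multimodal extend path additionally passes pixel values, image sizes,
    # and image offsets to the model.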
    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

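    # Dispatch on forward mode: multimodal models take the image-aware extend path.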
    def forward(self, batch: Batch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


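# Build the model-architecture registry by scanning sglang.srt.models for modules
# that export an EntryClass (and an optional EntryClassRemapping); the result is
# cached via lru_cache.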
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models (e.g. chatglm) have an incorrect class name set in config.json
            # usage: EntryClassRemapping = [("From_Entry_Class_Name", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


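# Resolve a model architecture name to its SGLang implementation class.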
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()
    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)