import importlib
import importlib.resources
import inspect
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List

import numpy as np
import torch
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed import initialize_model_parallel

from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory

QUANTIZATION_CONFIG_MAPPING = {
    "awq": AWQConfig,
    "gptq": GPTQConfig,
    "marlin": MarlinConfig,
}

logger = logging.getLogger("model_runner")

# Server args shared with model code (e.g., "enable_flashinfer"); set by ModelRunner.
global_server_args_dict = {}


@lru_cache()
def import_model_classes():
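    """Discover model classes under sglang.srt.models via each module's EntryClass attribute."""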
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
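    """Return the model class for the first supported architecture name."""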
    model_arch_name_to_cls = import_model_classes()

    model_class = None
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            model_class = model_arch_name_to_cls[arch]
            break
    else:
        raise ValueError(
            f"Unsupported architectures: {arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_class


@dataclass
class InputMetadata:
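    """Per-batch metadata for a forward pass: shapes, positions, and KV-cache bookkeeping."""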
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
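        """Build the paged-KV layout (indptr/indices) and prepare a flashinfer attention wrapper."""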
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        seq_lens_cpu = self.seq_lens.cpu().numpy()
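        # Gather each request's KV-cache slot indices from the req_to_token table.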
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

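        # Workspace buffer required by the flashinfer wrapper kernels.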
        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            args = [
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
            ]

            self.prefill_wrapper.begin_forward(*args)
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,  # page size (token-level paging)
                "NONE",  # positional-encoding mode applied in the kernel
                "float16",  # KV data type
            )

    def init_extend_args(self):
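        """Compute per-request extend lengths, i.e. the new tokens beyond the cached prefix."""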
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
    ):
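        """Assemble the metadata for one batch, including positions and optional flashinfer args."""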
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if global_server_args_dict.get("enable_flashinfer", False):
            ret.init_flashinfer_args(tp_size)

        return ret


class ModelRunner:
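    """Holds one tensor-parallel rank: loads the model and runs batched forward passes."""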
    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        server_args_dict: dict = None,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code

        global global_server_args_dict
        global_server_args_dict = server_args_dict or {}

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

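        # Profile free GPU memory, then load weights and size the KV-cache pools.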
        total_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])
        model_class = get_model_cls_by_arch_name(architectures)
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        quant_config = None

        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()
            # compat: autogptq >=0.8.0 uses checkpoint_format: str
            # compat: autogptq <=0.7.1 uses is_marlin_format: bool
            is_format_marlin = quant_cfg.get(
                "checkpoint_format"
            ) == "marlin" or quant_cfg.get("is_marlin_format", False)

            # Use marlin if the GPTQ model is serialized in marlin format.
            if quant_method == "gptq" and is_format_marlin:
                quant_method = "marlin"

            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)

            if quant_config_class is None:
                raise ValueError(f"Unsupported quantization method: {quant_method}")

            quant_config = quant_config_class.from_config(quant_cfg)
            logger.info(f"quant_config: {quant_config}")

        with set_default_torch_dtype(torch.float16):
            with torch.device("cuda"):
                model = model_class(
                    config=self.model_config.hf_config, quant_config=quant_config
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )
        self.model = model.eval()

        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
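        # KV-cache bytes per token: K and V (x2) in fp16 (2 bytes) across all layers.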
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_token <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

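        # Heuristic sizing: assume the average request is much shorter (~1/256)
        # than the full context length.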
        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_token / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_token,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            out_cache_cont_start=batch.out_cache_cont_start,
            out_cache_cont_end=batch.out_cache_cont_end,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode):
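        """Dispatch to the forward variant matching the batch's forward mode."""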
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")