import importlib
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import List

import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

import sglang

QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

logger = logging.getLogger("model_runner")


# for model_mode
global_model_mode: List[str] = []


@lru_cache()
def import_model_classes():
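    """Map model architecture names to their classes.

    Scans every module under sglang/srt/models and registers the module-level
    ``EntryClass`` it exports; cached so the scan runs only once per process.
    """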
    model_arch_name_to_cls = {}
    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
        if hasattr(module, "EntryClass"):
            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
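    """Return the model class for the first supported name in ``model_arch_names``."""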
    model_arch_name_to_cls = import_model_classes()

    model_class = None
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            model_class = model_arch_name_to_cls[arch]
            break
    else:
        raise ValueError(
            f"Unsupported architectures: {arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_class


@dataclass
class InputMetadata:
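    """Per-batch metadata for a forward pass: shapes, positions, KV-cache pool
    handles and indices, plus optional extend- and flashinfer-specific fields."""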
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False

    # for flashinfer
    use_flashinfer: bool = False
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
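        """Build the paged-KV indptr/indices/last-page-len tensors (page size 1)
        and start a flashinfer prefill or decode wrapper for this batch."""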
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
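        # Gather, request by request, the KV-cache slot index of every cached token
        # (page size is 1, so there is exactly one slot index per token).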
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
            )
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,
                "NONE",
                "float16",
            )

    def init_extend_args(self):
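        """Compute per-request extend (non-prefix) lengths and their start offsets."""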
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_logprob=False,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
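            # Decode appends exactly one token per request, so its position is the
            # current sequence length minus one, plus any per-request offset.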
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i] + position_ids_offsets_np[i],
                            seq_lens_np[i] + position_ids_offsets_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_logprob=return_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret


class ModelRunner:
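    """Loads the model onto one tensor-parallel rank, allocates the KV-cache
    pools, and runs the prefill / extend / decode forward passes."""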
    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        model_mode: List[str] = (),
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code
        self.model_mode = model_mode

        global global_model_mode
        global_model_mode = model_mode

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        # A small all_reduce for warmup.
        if self.tp_size > 1:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        total_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])
        model_class = get_model_cls_by_arch_name(architectures)
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        linear_method = None
        with _set_default_torch_dtype(torch.float16):
            with torch.device("cuda"):
                hf_quant_config = getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
                if hf_quant_config is not None:
                    quant_config_class = QUANTIONCONFIG_MAPPING.get(
                        hf_quant_config["quant_method"]
                    )
                    if quant_config_class is None:
                        raise ValueError(
                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
                        )
                    quant_config = quant_config_class.from_config(hf_quant_config)
                    logger.info(f"quant_config: {quant_config}")
                    linear_method = quant_config.get_linear_method()
                model = model_class(
                    config=self.model_config.hf_config, linear_method=linear_method
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )
        self.model = model.eval()

        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):
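        """Estimate how many KV-cache tokens fit in the memory budget reserved
        by --mem-fraction-static after the weights are loaded."""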
        available_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        head_dim = (
            self.model_config.hidden_size // self.model_config.num_attention_heads
        )
        head_num = self.model_config.num_key_value_heads // self.tp_size
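        # KV-cache bytes per token: K and V (x2) stored in fp16 (2 bytes) for every layer.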
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
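        """Size and allocate the request-to-token and token-to-KV-cache pools."""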
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_token <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_token / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_token,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.hidden_size
            // self.model_config.num_attention_heads,
            layer_num=self.model_config.num_hidden_layers,
        )

    @torch.inference_mode()
    def forward_prefill(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_decode(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend_multi_modal(
        self,
        input_ids,
        pixel_values,
        image_sizes,
        image_offsets,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(
            input_ids,
            input_metadata.positions,
            input_metadata,
            pixel_values,
            image_sizes,
            image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
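        """Dispatch the batch to the prefill, extend, decode, or multimodal-extend path."""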
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            kwargs = {
                "input_ids": batch.input_ids,
                "pixel_values": batch.pixel_values,
                "image_sizes": batch.image_sizes,
                "image_offsets": batch.image_offsets,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
                "return_logprob": return_logprob,
            }
            return self.forward_extend_multi_modal(**kwargs)
        else:
            kwargs = {
                "input_ids": batch.input_ids,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
                "return_logprob": return_logprob,
            }

        if forward_mode == ForwardMode.DECODE:
            kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
            kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
            return self.forward_decode(**kwargs)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(**kwargs)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(**kwargs)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")