# model_runner.py

import importlib
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import List

import numpy as np
import torch
import sglang
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

logger = logging.getLogger("model_runner")


# Global copy of the model_mode flags, set by ModelRunner.__init__
global_model_mode: List[str] = []


@lru_cache()
def import_model_classes():
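    """Map each architecture name to the EntryClass exported by a module in sglang/srt/models."""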
    model_arch_name_to_cls = {}
    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
        if hasattr(module, "EntryClass"):
            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
    model_arch_name_to_cls = import_model_classes()

    model_class = None
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            model_class = model_arch_name_to_cls[arch]
            break
    else:
        raise ValueError(
            f"Unsupported architectures: {arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_class


@dataclass
class InputMetadata:
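    """Metadata describing one batch for a single forward pass."""
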
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False

    # for flashinfer
    use_flashinfer: bool = False
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
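        # Build the page-size-1 paged KV-cache index tensors (indptr, indices,
        # last-page lengths) expected by the flashinfer wrappers.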
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

        from flashinfer.ops import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.batch_size,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
            )
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_last_page_len,
                self.batch_size,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,
                "NONE",
                "float16",
            )

    def init_extend_args(self):
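        # Lengths and start offsets of only the newly extended (non-prefix) tokens.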
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_logprob=False,
    ):
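        """Build an InputMetadata from raw batch tensors for the given forward mode."""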
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i] + position_ids_offsets_np[i],
                            seq_lens_np[i] + position_ids_offsets_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_logprob=return_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret


class ModelRunner:
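    """Loads model weights and executes batched forward passes on one tensor-parallel rank."""
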
    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        model_mode: List[str] = (),
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code
        self.model_mode = model_mode

        global global_model_mode
        global_model_mode = model_mode

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        # A small all_reduce for warmup.
        if self.tp_size > 1:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        total_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])
        model_class = get_model_cls_by_arch_name(architectures)
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        linear_method = None
        with _set_default_torch_dtype(torch.float16):
            with torch.device("cuda"):
                hf_quant_config = getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
                if hf_quant_config is not None:
                    # TODO: support quantization configs other than AWQ
                    quant_config = AWQConfig.from_config(hf_quant_config)
                    logger.info(f"quant_config: {quant_config}")
                    linear_method = quant_config.get_linear_method()
                model = model_class(
                    config=self.model_config.hf_config, linear_method=linear_method
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )
        self.model = model.eval()

        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        head_dim = (
            self.model_config.hidden_size // self.model_config.num_attention_heads
        )
        head_num = self.model_config.num_key_value_heads // self.tp_size
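        # Bytes per cached token: kv_heads * head_dim * layers * 2 (K and V) * 2 bytes (fp16).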
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_token <= 0:
            raise RuntimeError(
                "Not enought memory. " "Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_token / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_token,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.hidden_size
            // self.model_config.num_attention_heads,
            layer_num=self.model_config.num_hidden_layers,
        )

    @torch.inference_mode()
    def forward_prefill(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_decode(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
            0
        ]

    @torch.inference_mode()
    def forward_extend_multi_modal(
        self,
        input_ids,
        pixel_values,
        image_sizes,
        image_offsets,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(
            input_ids,
            input_metadata.positions,
            input_metadata,
            pixel_values,
            image_sizes,
            image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
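        # Dispatch to the mode-specific forward pass based on forward_mode and modality.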
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            kwargs = {
                "input_ids": batch.input_ids,
                "pixel_values": batch.pixel_values,
                "image_sizes": batch.image_sizes,
                "image_offsets": batch.image_offsets,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
            }
            kwargs["return_logprob"] = return_logprob
            return self.forward_extend_multi_modal(**kwargs)
        else:
            kwargs = {
                "input_ids": batch.input_ids,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
            }

        if forward_mode == ForwardMode.DECODE:
            kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
            kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
            return self.forward_decode(**kwargs)
        elif forward_mode == ForwardMode.EXTEND:
            kwargs["return_logprob"] = return_logprob
            return self.forward_extend(**kwargs)
        elif forward_mode == ForwardMode.PREFILL:
            kwargs["return_logprob"] = return_logprob
            return self.forward_prefill(**kwargs)
        else:
            raise ValueError(f"Invaid forward mode: {forward_mode}")