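"""ModelRunner for SGLang's SRT server: loads model weights (reusing vLLM's
loader and quantization utilities) and runs prefill / extend / decode forward
passes, optionally with flashinfer paged-attention kernels."""
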
import importlib
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

import sglang

QUANTIZATION_CONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

logger = logging.getLogger("model_runner")


# Server args shared globally so model-endpoint code can read them.
global_server_args_dict: Optional[dict] = None


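# Discover model implementations dynamically: every module in sglang/srt/models
# that defines an ``EntryClass`` is registered under that class's name, which
# matches the entries of the HF config's ``architectures`` field.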
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
        if hasattr(module, "EntryClass"):
            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
    model_arch_name_to_cls = import_model_classes()

    # Pick the first candidate architecture that has a registered class;
    # the for/else raises only when no candidate matched.
    model_class = None
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            model_class = model_arch_name_to_cls[arch]
            break
    else:
        raise ValueError(
            f"Unsupported architectures: {model_arch_names}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_class


@dataclass
class InputMetadata:
    """Per-batch metadata (shapes, positions, cache indices) consumed by the
    model's forward pass."""

    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

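        # Page size is 1: each token occupies its own KV "page", so kv_indices
        # is the flattened list of per-request token slots taken from
        # req_to_token_pool, and kv_last_page_len is all ones.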
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

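        # Scratch buffer shared by the flashinfer prefill/decode wrappers.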
        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
            )
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,  # page size: one token per page
                "NONE",  # no positional encoding applied inside the kernel
                "float16",  # KV cache data type
            )

    def init_extend_args(self):
        # Extend reuses the cached prefix KV, so only the new ("extend") part
        # of each sequence needs computation; record its lengths and offsets.
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_logprob=False,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            # Decode appends exactly one token per request, so its position is
            # seq_len - 1 plus any offset. other_kv_index records the KV slot
            # of the first request's last token.
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i] + position_ids_offsets_np[i],
                            seq_lens_np[i] + position_ids_offsets_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_logprob=return_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if global_server_args_dict["enable_flashinfer"]:
            ret.init_flashinfer_args(tp_size)

        return ret


class ModelRunner:
    """Owns distributed setup, weight loading, the KV-cache memory pools, and
    the forward entry points for one tensor-parallel rank."""

    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        server_args_dict: Optional[dict] = None,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code

        global global_server_args_dict
        global_server_args_dict = (
            server_args_dict if server_args_dict is not None else {}
        )

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        # A small all_reduce for warmup.
        if self.tp_size > 1:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        total_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)  # convert GiB to bytes
        self.load_model()
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])
        model_class = get_model_cls_by_arch_name(architectures)
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        linear_method = None
        with _set_default_torch_dtype(torch.float16):
            with torch.device("cuda"):
                hf_quant_config = getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
                if hf_quant_config is not None:
                    quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(
                        hf_quant_config["quant_method"]
                    )
                    if quant_config_class is None:
                        raise ValueError(
                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
                        )
                    quant_config = quant_config_class.from_config(hf_quant_config)
                    logger.info(f"quant_config: {quant_config}")
                    linear_method = quant_config.get_linear_method()
                model = model_class(
                    config=self.model_config.hf_config, linear_method=linear_method
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )
        self.model = model.eval()

        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)
        head_dim = self.model_config.head_dim
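        # cell_size is the KV-cache footprint of one token. With illustrative
        # numbers (32 layers, 8 KV heads per GPU, head_dim 128, fp16):
        # 8 * 128 * 32 * 2 (K and V) * 2 (bytes) = 131072 bytes = 128 KiB.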
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_token <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

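        # Two-level mapping: req_to_token_pool maps (request slot, position) ->
        # token slot; token_to_kv_pool holds the per-token KV cache itself.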
        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_token / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_token,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )

    @torch.inference_mode()
    def forward_prefill(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_decode(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_logprob=return_logprob,
        )
        return self.model.forward(input_ids, input_metadata.positions, input_metadata)

    @torch.inference_mode()
    def forward_extend_multi_modal(
        self,
        input_ids,
        pixel_values,
        image_sizes,
        image_offsets,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_logprob,
    ):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            return_logprob=return_logprob,
        )
        return self.model.forward(
            input_ids,
            input_metadata.positions,
            input_metadata,
            pixel_values,
            image_sizes,
            image_offsets,
        )

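    # Dispatch a batch to the forward_* helper that matches its forward mode.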
    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            kwargs = {
                "input_ids": batch.input_ids,
                "pixel_values": batch.pixel_values,
                "image_sizes": batch.image_sizes,
                "image_offsets": batch.image_offsets,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
                "return_logprob": return_logprob,
            }
            return self.forward_extend_multi_modal(**kwargs)
        else:
            kwargs = {
                "input_ids": batch.input_ids,
                "req_pool_indices": batch.req_pool_indices,
                "seq_lens": batch.seq_lens,
                "prefix_lens": batch.prefix_lens,
                "position_ids_offsets": batch.position_ids_offsets,
                "out_cache_loc": batch.out_cache_loc,
                "return_logprob": return_logprob,
            }

        if forward_mode == ForwardMode.DECODE:
            kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
            kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
            return self.forward_decode(**kwargs)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(**kwargs)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(**kwargs)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")
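

# Example usage (a sketch; the config objects and values below are
# illustrative, not part of this module):
#
#     runner = ModelRunner(
#         model_config=model_config,  # sglang model config wrapping the HF config
#         mem_fraction_static=0.9,
#         tp_rank=0,
#         tp_size=1,
#         nccl_port=28888,
#         server_args_dict={"enable_flashinfer": False},
#     )
#     logits = runner.forward(batch, ForwardMode.EXTEND, return_logprob=False)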