# model_runner.py (committed by lizhigong)


import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """GPU model-input builder that also tracks sequence ids per batch.

    Behaves like ``ModelInputForGPUBuilder``. In addition, it records the
    sequence ids of every sequence group added since the last ``prepare()``,
    in insertion order. When ``build()`` runs and ``get_last_sampler()``
    returns a sampler from a previous step, the builder:

    * copies those ids to the runner's device; and
    * passes them to ``UpdateInputTokens`` together with the previous step's
      sampled token ids and sequence ids.

    NOTE(review): ``UpdateInputTokens`` is opaque from this file. Its return
    value is ignored and ``model_input`` is returned unchanged, so it
    presumably patches ``model_input.input_tokens`` in place with the tokens
    sampled in the previous step, matched by sequence id. Confirm against
    ``vllm.zero_overhead.v0.update_input``.
    """

    def __init__(self,
                 runner,
                 finished_requests_ids: Optional[List[str]] = None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids (keys of ``seq_group_metadata.seq_data``) in the order
        # their groups were added. Despite the name, these are per-sequence
        # ids rather than request ids; the attribute name is kept so existing
        # readers of ``req_ids`` keep working.
        self.req_ids: List[int] = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset per-batch state, including the recorded sequence ids."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then delegate to the base class."""
        # extend() keeps the dict's key order. NOTE(review): this assumes the
        # base builder lays sequences out in the same order; confirm if the
        # base class ever reorders sequences.
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input, then patch in the last step's samples.

        Returns:
            The ``ModelInputForGPU`` produced by the base builder. Its
            ``input_tokens`` may have been updated by ``UpdateInputTokens``
            (see the class docstring).
        """
        model_input = super().build()
        last_sampler = get_last_sampler()

        # Nothing to patch if no previous step exists. The same applies when
        # the base builder produced no input tokens: the base class can
        # return an empty model input whose ``input_tokens`` is None, and
        # that must not be passed to UpdateInputTokens.
        if last_sampler is None or model_input.input_tokens is None:
            return model_input

        device = self.runner.device
        pin_memory = self.runner.pin_memory
        # Sequence ids of the batch being built, in input order.
        cur_seq_ids = async_tensor_h2d(self.req_ids, torch.long,
                                       device, pin_memory)
        # Sequence ids the previous sampler produced tokens for.
        prev_seq_ids = async_tensor_h2d(last_sampler.seq_id.tolist(),
                                        torch.long, device, pin_memory)
        UpdateInputTokens(model_input.input_tokens, cur_seq_ids,
                          last_sampler.sampled_token_ids_tensor,
                          prev_seq_ids)
        return model_input