# model_runner.py
# Zero-overhead (v0) variant of the GPU model-input builder.
# Last committed by: lizhigong


import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Model-input builder that patches in the previous step's sampled tokens.

    On top of ``ModelInputForGPUBuilder``, this builder records the sequence
    ids of every sequence group added for the current step, in the order
    they are added. In ``build`` it looks those ids up in the sampler
    returned by ``get_last_sampler()``. It then overwrites the matching
    entries of ``model_input.input_tokens`` with that sampler's
    ``sampled_token_ids_tensor[:, 0]``, using on-device indexing.
    """

    def __init__(self, runner, finished_requests_ids=None):
        # Create the id list before delegating to the base constructor, so
        # the overridden ``prepare`` (which clears it) is safe to call even if
        # the base constructor invokes it.
        self.req_ids = []
        super().__init__(runner, finished_requests_ids)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset the recorded sequence ids, then run the base ``prepare``."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then delegate to the base class.

        Ids are appended in ``seq_data`` iteration order. ``build`` treats
        position ``i`` in ``self.req_ids`` as row ``i`` of ``input_tokens``.
        """
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    @staticmethod
    def _match_seq_indices(current_seq_ids, last_seq_ids):
        """Pair each current sequence with its row in the last sampler.

        Returns ``(update_indices, select_indices)``. For every position
        ``i`` in ``current_seq_ids`` whose id also occurs in
        ``last_seq_ids``, ``i`` is appended to ``update_indices``, and the
        id's FIRST position in ``last_seq_ids`` is appended to
        ``select_indices``. Runs in O(n + m) rather than O(n * m).
        """
        if isinstance(last_seq_ids, torch.Tensor):
            # Tensor elements hash by identity, so convert to plain values
            # before using them as dict keys.
            last_seq_ids = last_seq_ids.tolist()
        last_pos = {}
        for j, seq_id in enumerate(last_seq_ids):
            # setdefault keeps the first occurrence, matching a nested scan
            # that breaks on the first match.
            last_pos.setdefault(seq_id, j)

        update_indices = []
        select_indices = []
        for i, seq_id in enumerate(current_seq_ids):
            j = last_pos.get(seq_id)
            if j is not None:
                update_indices.append(i)
                select_indices.append(j)
        return update_indices, select_indices

    def build(self) -> ModelInputForGPU:
        """Build the model input and splice in the last sampler's tokens.

        If ``get_last_sampler()`` returns ``None``, or no recorded sequence
        id appears in ``last_sampler.seq_ids``, the base result is returned
        unchanged.
        """
        model_input = super().build()
        last_sampler = get_last_sampler()
        if last_sampler is None:
            return model_input

        update_indices, select_indices = self._match_seq_indices(
            self.req_ids, last_sampler.seq_ids)
        # Nothing to patch: skip both host-to-device copies entirely.
        if not update_indices:
            return model_input

        select_indices = async_tensor_h2d(select_indices, torch.long,
                                          self.runner.device,
                                          self.runner.pin_memory)
        update_indices = async_tensor_h2d(update_indices, torch.long,
                                          self.runner.device,
                                          self.runner.pin_memory)
        # NOTE(review): a position in self.req_ids is used directly as a row
        # index into input_tokens. That assumes exactly one input token per
        # sequence (decode). Confirm this never runs on batches that also
        # contain multi-token (prefill) sequences.
        model_input.input_tokens[update_indices] = (
            last_sampler.sampled_token_ids_tensor[select_indices, 0])
        return model_input