# model_runner.py


import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.update_input import UpdateInputTokens


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """``ModelInputForGPUBuilder`` that patches in the previous step's samples.

    Besides the base builder's work, this class records the sequence ids of
    every sequence group added since the last ``prepare`` call, in the order
    they were added. After the base ``build`` has produced the model input,
    every recorded sequence id that also appears in the sampler returned by
    ``get_last_sampler()`` has its entry in ``model_input.input_tokens``
    overwritten, in place, with
    ``last_sampler.sampled_token_ids_tensor[row, 0]``.

    NOTE(review): the position of a sequence id in ``self.req_ids`` is used
    directly as an index into ``input_tokens``. That presumes exactly one
    input token per sequence, laid out in the same order the base builder
    emits them (as in a pure-decode batch) -- confirm for prefill and
    chunked-prefill batches.
    """

    def __init__(self, runner, finished_requests_ids=None):
        # Create the per-batch id list before delegating, so ``prepare``
        # (which resets it) is safe to call at any point, including from
        # the base constructor.
        self.req_ids: List[int] = []
        super().__init__(runner, finished_requests_ids)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset the recorded sequence ids, then run the base ``prepare``."""
        # Rebind rather than ``clear()`` so this also works if the attribute
        # has not been created yet.
        self.req_ids = []
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then run the base method."""
        # Ids are appended in ``seq_data`` key order.
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and overwrite tokens sampled last step.

        Returns:
            The ``ModelInputForGPU`` produced by the base builder, with
            ``input_tokens`` patched in place wherever a recorded sequence id
            matches one in the last sampler's ``seq_ids``. It is returned
            unchanged when there is no last sampler, no ``input_tokens``
            (empty batch), or no matching id.
        """
        model_input = super().build()
        last_sampler = get_last_sampler()
        if last_sampler is None or model_input.input_tokens is None:
            return model_input

        update_indices, select_indices = self._match_last_sampler(
            last_sampler.seq_ids)
        # Nothing to patch: skip the host-to-device index copies entirely.
        if not update_indices:
            return model_input

        select_indices = async_tensor_h2d(select_indices, torch.long,
                                          self.runner.device,
                                          self.runner.pin_memory)
        update_indices = async_tensor_h2d(update_indices, torch.long,
                                          self.runner.device,
                                          self.runner.pin_memory)
        # Column 0 of each matched sampler row is written; presumably one
        # sampled token per sequence -- TODO confirm tensor shape.
        model_input.input_tokens[update_indices] = (
            last_sampler.sampled_token_ids_tensor[select_indices, 0])
        return model_input

    def _match_last_sampler(self, sampler_seq_ids):
        """Pair positions in ``self.req_ids`` with rows of ``sampler_seq_ids``.

        Returns ``(update_indices, select_indices)``: parallel lists such that
        ``self.req_ids[update_indices[k]] == sampler_seq_ids[select_indices[k]]``,
        ordered by position in ``self.req_ids``. If an id occurs more than
        once in ``sampler_seq_ids``, its first row is used.
        """
        if isinstance(sampler_seq_ids, torch.Tensor):
            # Tensors hash by identity; convert to Python scalars so the
            # dict lookup below compares ids by value.
            sampler_seq_ids = sampler_seq_ids.tolist()

        # One pass over the sampler ids instead of a linear search per
        # request id; setdefault keeps the first matching row.
        row_of = {}
        for row, seq_id in enumerate(sampler_seq_ids):
            row_of.setdefault(seq_id, row)

        update_indices = []
        select_indices = []
        for pos, seq_id in enumerate(self.req_ids):
            row = row_of.get(seq_id)
            if row is not None:
                update_indices.append(pos)
                select_indices.append(row)
        return update_indices, select_indices