model_runner.py 8.71 KB
Newer Older
lizhigong's avatar
lizhigong committed
1
2
3
4
5
6
7
8
9
10
11
12


import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
lizhigong's avatar
lizhigong committed
13
from vllm.zero_overhead.sampler import get_last_sampler
lizhigong's avatar
lizhigong committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step

import triton
import triton.language as tl

@triton.jit
def _update_input_tokens(
    accepted_req_ids,
    accepted_req_ids_len,
    accepted_token_ids,
    accepted_token_len,
    chidren_req_ids,
    chidren_req_ids_len,
    input_tokens,
    input_tokens_len,
    input_positions,
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,
    context_lens_tensor,
):
    """Rewrite per-request input buffers from the accepted token ids.

    The kernel walks ``chidren_req_ids`` two entries at a time, looks each
    pair's first id up in ``accepted_req_ids`` and, on a match, rewrites the
    pair's two slots in ``input_tokens`` (and, when not every token was
    accepted, also in ``input_positions``, ``context_lens_tensor``,
    ``slot_mapping`` and the three ``seq_lens*`` buffers). Finally
    ``seq_start_loc`` is rebuilt as a running sum of ``seq_lens``.

    A row of ``accepted_token_ids`` is read as ``accepted_token_len``
    entries in which ``-1`` marks the first non-accepted position.

    NOTE(review): several constructs used below (``break``, element
    assignment on loaded values, ``tl.zero_like``, a ``tl.arange`` whose
    second argument is itself a ``tl.arange``) should be checked against the
    Triton language reference -- it is not evident from this file that the
    kernel compiles as written.
    """
    chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): this load is sized by chidren_req_ids_len, while the
    # lookup loop below iterates accepted_req_ids_len entries -- confirm
    # which length is intended here.
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
    
    # Each iteration handles one pair of consecutive entries; only the first
    # id of the pair (index 2 * seq_id_idx) is used for the lookup.
    for seq_id_idx in range(chidren_req_ids_len / 2):
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
                # Count the leading entries that are not the -1 marker.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Every entry was accepted: the pair's two input tokens
                    # become the last two accepted token ids.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: the first slot of the pair is
                    # filled with 0 and the second with the last accepted
                    # token id.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
                    input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                    input_pos[0] = 0
                    # NOTE(review): the subtrahend uses accepted_req_ids_len,
                    # whereas accepted_token_counter is bounded by
                    # accepted_token_len -- confirm which is intended.
                    input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                    # The same (0, adjusted position) pair is written to both
                    # the positions and the context lengths.
                    tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # slot_mapping receives (-1, adjusted position).
                    input_pos[0] = -1
                    tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # All three seq_lens buffers receive
                    # (1, adjusted position + 1).
                    input_pos[0] = 1
                    input_pos[1] = input_pos[1] + 1
                    tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild seq_start_loc so that entry i + 1 is entry i plus seq_lens[i],
    # starting from zero.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    # NOTE(review): triton.language documents `zeros_like`; confirm that
    # `zero_like` resolves.
    seq_start_loc_ = tl.zero_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)

lizhigong's avatar
lizhigong committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95


class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Model-input builder that patches ``input_tokens`` after the base build.

    The builder records the sequence ids of every sequence group added to the
    batch. :meth:`build` then overwrites entries of
    ``model_input.input_tokens`` with token ids taken from the last sampler's
    ``sampled_token_ids_tensor`` or from the speculative proposal/accepted
    token ids, depending on the current and previous ``SpecStepKind``.
    """

    def __init__(self, runner, finished_requests_ids = None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the batch under construction, in the order in which
        # add_seq_group() received them.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Forget the recorded sequence ids, then defer to the base class."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then defer to the base class."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def _indices_to_device(self, indices):
        """Return ``indices`` as a ``torch.long`` tensor on the runner device."""
        return async_tensor_h2d(indices, torch.long,
                                self.runner.device,
                                self.runner.pin_memory)

    def _copy_matching_sampled_tokens(self, model_input, last_sampler) -> None:
        """Copy last-step sampled tokens into the matching requests' slots.

        A request is matched when its sequence id equals one of
        ``last_sampler.seq_ids`` (first match wins). The token is written at
        the request's starting offset in ``input_tokens``, i.e. the sum of
        the ``query_lens`` of all preceding requests.
        """
        sampled = last_sampler.sampled_token_ids_tensor
        if sampled is None:
            return

        select_indices = []
        update_indices = []
        query_offset = 0
        for req_idx, seq_id in enumerate(self.req_ids):
            # Plain `==` comparison is kept (rather than a dict lookup)
            # because the element type of last_sampler.seq_ids is not
            # visible from this file.
            match = next(
                (prev_idx
                 for prev_idx, prev_seq_id in enumerate(last_sampler.seq_ids)
                 if seq_id == prev_seq_id),
                None,
            )
            if match is not None:
                select_indices.append(match)
                update_indices.append(query_offset)
            query_offset += model_input.query_lens[req_idx]

        if not select_indices:
            return
        select_indices = self._indices_to_device(select_indices)
        update_indices = self._indices_to_device(update_indices)
        model_input.input_tokens[update_indices] = sampled[select_indices, 0]

    def build(self) -> ModelInputForGPU:
        """Build the model input and patch its tokens for the current step.

        Returns:
            The ``ModelInputForGPU`` produced by the base builder, with
            ``input_tokens`` (and, for the score-decode -> first-proposal
            transition, further buffers handed to ``_update_input_tokens``)
            updated in place.
        """
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()

        if last_sampler is not None:
            if spec_step == SpecStepKind.KIND_DEFAULT:
                self._copy_matching_sampled_tokens(model_input, last_sampler)
            if (spec_step == SpecStepKind.OTHER_PROPOSAL
                    and last_step in (SpecStepKind.OTHER_PROPOSAL,
                                      SpecStepKind.FIRST_PROPOSAL)):
                # Copy the last sampled token ids to the input tokens
                # directly; both previous-step kinds are handled identically.
                # TODO: after FIRST_PROPOSAL, adjust the number of input
                # tokens to 1 per request.
                update_indices = self._indices_to_device(
                    list(range(len(self.req_ids))))
                model_input.input_tokens[update_indices] = (
                    last_sampler.sampled_token_ids_tensor[update_indices, 0])

        if spec_step == SpecStepKind.SCORE_DECODE:
            proposal_token_ids = get_proposal_token_ids()
            shape = proposal_token_ids.shape
            batch_size = shape[0]
            proposal_len = shape[1]
            # Each request owns proposal_len + 1 consecutive slots; slot 0 is
            # left untouched and slots 1..proposal_len receive the proposals.
            stride = proposal_len + 1
            update_indices = self._indices_to_device([
                row * stride + col + 1
                for row in range(batch_size)
                for col in range(proposal_len)
            ])
            model_input.input_tokens[update_indices] = (
                proposal_token_ids.view(-1))

        if spec_step == SpecStepKind.FIRST_PROPOSAL:
            # TODO: when the last step is prefill, only update the input ids
            # for the last sequence id.
            if last_step == SpecStepKind.SCORE_DECODE:
                # TODO: when the last step is score decode, fix input ids,
                # seq_lens and input_positions using the accepted token ids.
                accept_token_ids, accept_seq_ids = get_accepted_token_ids()
                chidren_req_ids = self._indices_to_device(self.req_ids)
                attn_metadata = model_input.attn_metadata
                grid = (1, 1, 1)
                # Arguments follow the kernel's parameter order; the two
                # attn_metadata seq_lens entries are passed so that each
                # lands on the parameter of the same name.
                _update_input_tokens[grid](
                    accept_seq_ids, accept_seq_ids.shape[0],
                    accept_token_ids, accept_token_ids.shape[1],
                    chidren_req_ids, chidren_req_ids.shape[0],
                    model_input.input_tokens, model_input.input_tokens.shape[0],
                    model_input.input_positions,
                    model_input.seq_lens,           # seq_lens
                    attn_metadata.seq_lens,         # seq_lens_meta
                    attn_metadata.seq_lens_tensor,  # seq_lens_tensor
                    attn_metadata.slot_mapping,
                    attn_metadata.seq_start_loc,
                    attn_metadata.context_lens_tensor,
                )

        return model_input