# utils.py (7.19 KB) -- repository-viewer header residue ("Newer Older")
1
import time
2
3
4
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
5
6

from vllm import SamplingParams
7
from vllm.lora.request import LoRARequest
8
from vllm.sequence import Logprob, Sequence, SequenceGroup
9
10
11


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
    prompt_tokens: Optional[List[int]] = None,
    min_tokens: int = 0,
    max_tokens: int = 16,
) -> Tuple[Sequence, SequenceGroup]:
    """Build a dummy prompt ``Sequence`` and a ``SequenceGroup`` holding it.

    Args:
        request_id: Request id for the group. It is also passed through
            ``int()`` to obtain the sequence id, so it must be a numeric
            string.
        prompt_length: Number of dummy tokens to generate when
            ``prompt_tokens`` is not given; also the fallback block size.
        block_size: Block size for the sequence. Any falsy value (``None``
            or ``0``) falls back to ``prompt_length``.
        lora_request: Forwarded unchanged to the ``SequenceGroup``.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.
        prompt_tokens: Explicit prompt token ids. When ``None``, the tokens
            ``0 ... prompt_length - 1`` are used.
        min_tokens: Forwarded to ``SamplingParams``.
        max_tokens: Forwarded to ``SamplingParams``.

    Returns:
        ``(prompt, seq_group)`` where ``seq_group`` contains only ``prompt``.
    """
    if not block_size:
        block_size = prompt_length

    if prompt_tokens is None:
        # Dummy prompt: token ids 0 ... prompt_length - 1.
        prompt_tokens = list(range(prompt_length))
    # The prompt text is simply the token ids joined by single spaces.
    prompt_str = " ".join(str(token) for token in prompt_tokens)

    prompt = Sequence(
        int(request_id),
        inputs={
            "prompt": prompt_str,
            "prompt_token_ids": prompt_tokens,
        },
        block_size=block_size,
    )

    sampling_params = SamplingParams(
        use_beam_search=use_beam_search,
        best_of=best_of,
        max_tokens=max_tokens,
        min_tokens=min_tokens,
    )
    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=[prompt],
        arrival_time=time.time(),
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    return prompt, seq_group


49
50
51
52
53
54
55
56
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    """Build dummy decoder/encoder prompt sequences and their group.

    Args:
        request_id: Request id for the group. It is also passed through
            ``int()`` to obtain the id used for both sequences, so it must
            be a numeric string.
        decoder_prompt_length: Number of decoder prompt tokens; also the
            fallback block size.
        encoder_prompt_length: Number of encoder prompt tokens.
        block_size: Block size for both sequences. Any falsy value
            (``None`` or ``0``) falls back to ``decoder_prompt_length``.
        lora_request: Forwarded unchanged to the ``SequenceGroup``.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        ``(decoder_prompt, encoder_prompt, seq_group)``. The group lists
        only the decoder sequence in ``seqs``; the encoder sequence is
        attached through ``encoder_seq``.
    """
    if not block_size:
        block_size = decoder_prompt_length

    # Decoder tokens count up: 0 ... decoder_prompt_length - 1.
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    # Encoder tokens count down: encoder_prompt_length - 1 ... 0.
    encoder_prompt_tokens = list(reversed(range(encoder_prompt_length)))

    # The prompt strings are just the token ids joined by spaces; they are
    # not real detokenized text.
    decoder_prompt_str = " ".join(str(t) for t in decoder_prompt_tokens)
    encoder_prompt_str = " ".join(str(t) for t in encoder_prompt_tokens)

    # One inputs dict is shared by both sequences; ``from_decoder_prompt``
    # is the only argument that differs between the two constructions.
    inputs = {
        "prompt": decoder_prompt_str,
        "prompt_token_ids": decoder_prompt_tokens,
        "encoder_prompt": encoder_prompt_str,
        "encoder_prompt_token_ids": encoder_prompt_tokens,
        "multi_modal_data": None,
    }

    decoder_prompt = Sequence(int(request_id),
                              inputs=inputs,
                              block_size=block_size,
                              from_decoder_prompt=True)

    encoder_prompt = Sequence(int(request_id),
                              inputs=inputs,
                              block_size=block_size,
                              from_decoder_prompt=False)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group


98
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a ``SequenceGroup`` with one sequence per requested output length.

    Args:
        seq_prompt_len: Number of prompt tokens; every prompt token id is 0.
        seq_output_lens: Output length for each sequence to create. Must be
            non-empty.
        request_id: Request id given to the group.
        seq_id_start: Id of the first sequence; later sequences get
            consecutive ids.
        sampling_params: Sampling parameters for the group. A default
            ``SamplingParams()`` is created when ``None``.

    Returns:
        The group containing all created sequences. Each sequence uses a
        block size of 16 and has had output tokens ``0 ... output_len - 1``
        appended, each with ``Logprob(0.0)``.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    # The same prompt token list object is handed to every sequence.
    prompt_token_ids = [0] * seq_prompt_len

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={"prompt_token_ids": prompt_token_ids},
            block_size=16,
        )

        # Simulate already-generated output: token i with logprob 0.0.
        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group


137
138
def create_seq_group_encoder_decoder(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a ``SequenceGroup`` with decoder sequences and one encoder sequence.

    Args:
        seq_prompt_len: Number of prompt tokens; every prompt token id is 0
            and the same list is used for decoder and encoder prompts.
        seq_output_lens: Output length for each decoder sequence to create.
            Must be non-empty.
        request_id: Request id given to the group.
        seq_id_start: Id of the first decoder sequence; later decoder
            sequences get consecutive ids.
        sampling_params: Sampling parameters for the group. A default
            ``SamplingParams()`` is created when ``None``.

    Returns:
        The group whose ``seqs`` are the decoder sequences (block size 16,
        output tokens ``0 ... output_len - 1`` appended with
        ``Logprob(0.0)``) and whose ``encoder_seq`` is the encoder sequence.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    # One inputs dict is shared by every decoder sequence and the encoder
    # sequence; both prompt strings are empty.
    inputs = {
        "prompt": "",
        "prompt_token_ids": prompt_token_ids,
        "encoder_prompt": "",
        "encoder_prompt_token_ids": prompt_token_ids,
        "multi_modal_data": None,
    }

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        # Construct decoder input sequences
        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
                       inputs=inputs,
                       block_size=16,
                       from_decoder_prompt=True)

        # Simulate already-generated output: token i with logprob 0.0.
        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder input sequence; its id is one past the last decoder id.
    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
                           inputs=inputs,
                           block_size=16,
                           from_decoder_prompt=False)

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)


187
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed for ``seq_len``.

    This is ceiling division of ``seq_len`` by ``block_size`` done in
    integer arithmetic. Despite the name, the result is a block count, not
    ``seq_len`` rounded up to a multiple of ``block_size``.

    Raises:
        ZeroDivisionError: If ``block_size`` is 0.
    """
    # Adding block_size - 1 before floor-dividing rounds any partial block up.
    return (seq_len + block_size - 1) // block_size


# Helper functions for scheduler tests


def get_sequence_groups(scheduler_output):
    """Return the ``seq_group`` of every scheduled entry, in order.

    Reads ``scheduler_output.scheduled_seq_groups`` and collects each
    entry's ``seq_group`` attribute into a new list.
    """
    groups = []
    for scheduled in scheduler_output.scheduled_seq_groups:
        groups.append(scheduled.seq_group)
    return groups


def append_new_token(out, token_id: int):
    """Append ``token_id`` to every sequence of every scheduled group.

    The groups come from ``get_sequence_groups(out)``. Each sequence gets
    its own logprobs dict mapping ``token_id`` to ``Logprob(token_id)``.
    """
    scheduled_groups = get_sequence_groups(out)
    # Flatten groups -> sequences lazily, so each group's get_seqs() is
    # called only when iteration reaches that group.
    every_seq = (seq for group in scheduled_groups
                 for seq in group.get_seqs())
    for seq in every_seq:
        logprobs = {token_id: Logprob(token_id)}
        seq.append_token_id(token_id, logprobs)


def schedule_and_update_computed_tokens(scheduler):
    """Run one scheduling step and record the computed tokens per group.

    Calls ``scheduler.schedule()``, which is expected to return a 3-tuple
    whose third element is ignored here. Each scheduled entry is paired
    positionally with its metadata, and the entry's ``seq_group`` is
    updated with that metadata's ``token_chunk_size``.

    Returns:
        ``(metas, out)`` - the first two values returned by
        ``scheduler.schedule()``.
    """
    metas, out, _ = scheduler.schedule()
    # zip pairs entries by position and stops at the shorter of the two.
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out


def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    """Record computed tokens on ``seq_group``, then append ``token_id``.

    First calls ``seq_group.update_num_computed_tokens(token_chunk_size)``,
    then appends ``token_id`` to each sequence from ``seq_group.get_seqs()``
    with a fresh logprobs dict mapping ``token_id`` to
    ``Logprob(token_id)``.
    """
    seq_group.update_num_computed_tokens(token_chunk_size)
    for sequence in seq_group.get_seqs():
        logprobs = {token_id: Logprob(token_id)}
        sequence.append_token_id(token_id, logprobs)