utils.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from collections.abc import Sequence as GenericSequence
7
8
from itertools import count
from typing import Any, Optional, Union
9

10
11
import torch

12
from vllm.core.scheduler import Scheduler, SchedulerOutputs
13
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
14
from vllm.lora.request import LoRARequest
15
16
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceData, SequenceGroup,
17
                           SequenceGroupMetadata)
18
19
20


def create_dummy_prompt(
    request_id: str,
    prompt_length: int = -1,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    prompt_tokens: Optional[list[int]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    min_tokens: int = 0,
    max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]:
    """Build a one-sequence ``SequenceGroup`` around a synthetic prompt.

    Args:
        request_id: Request identifier. It must parse as an integer because
            ``int(request_id)`` is used as the sequence id.
        prompt_length: Number of dummy tokens to generate when
            ``prompt_tokens`` is not supplied.
        block_size: Block size handed to the ``Sequence``. Any falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Optional LoRA request attached to the sequence group.
        prompt_tokens: Explicit prompt token ids; defaults to
            ``list(range(prompt_length))``.
        prompt_embeds: When given, the sequence inputs are built from these
            embeddings instead of from the token ids.
        min_tokens: ``min_tokens`` of the group's ``SamplingParams``.
        max_tokens: ``max_tokens`` of the group's ``SamplingParams``.

    Returns:
        ``(sequence, sequence_group)`` where the group holds only
        ``sequence``.
    """
    # NOTE(review): when prompt_tokens is passed and prompt_length is left at
    # its -1 default, this fallback yields block_size == -1. Callers in that
    # situation should pass block_size explicitly.
    if not block_size:
        block_size = prompt_length

    if prompt_tokens is None:
        # Dummy prompt made of the token ids 0 .. prompt_length - 1.
        prompt_tokens = list(range(prompt_length))

    # The prompt text is simply the space-separated token ids.
    prompt_str = " ".join([str(t) for t in prompt_tokens])

    if prompt_embeds is None:
        inputs = token_inputs(prompt_token_ids=prompt_tokens,
                              prompt=prompt_str)
    else:
        inputs = embeds_inputs(prompt_embeds=prompt_embeds)

    prompt = Sequence(
        int(request_id),
        inputs=inputs,
        block_size=block_size,
    )
    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=[prompt],
        arrival_time=time.time(),
        sampling_params=SamplingParams(max_tokens=max_tokens,
                                       min_tokens=min_tokens),
        lora_request=lora_request,
    )

    return prompt, seq_group


60
def create_dummy_lora_sequence(request_id: int, token_ids: list[int],
                               block_size: int, lora_int_id: int) -> Sequence:
    """Build a ``Sequence`` that carries a placeholder LoRA request.

    Args:
        request_id: Used directly as the sequence id.
        token_ids: Prompt token ids for the sequence.
        block_size: Block size handed to the ``Sequence``.
        lora_int_id: ``lora_int_id`` of the attached ``LoRARequest``; its
            name and path are the fixed placeholders ``"dummy"`` and
            ``"/dummy"``.

    Returns:
        The newly created ``Sequence``.
    """
    lora_request = LoRARequest(lora_name="dummy",
                               lora_path="/dummy",
                               lora_int_id=lora_int_id)
    return Sequence(seq_id=request_id,
                    inputs=token_inputs(token_ids),
                    block_size=block_size,
                    lora_request=lora_request)


70
def create_dummy_sequence(request_id: int, token_ids: list[int],
                          block_size: int) -> Sequence:
    """Build a plain ``Sequence`` from prompt token ids.

    Args:
        request_id: Used directly as the sequence id.
        token_ids: Prompt token ids for the sequence.
        block_size: Block size handed to the ``Sequence``.

    Returns:
        The newly created ``Sequence``.
    """
    return Sequence(
        seq_id=request_id,
        inputs=token_inputs(token_ids),
        block_size=block_size,
    )


79
80
81
82
83
84
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
) -> tuple[Sequence, Sequence, SequenceGroup]:
    """Build a dummy encoder/decoder sequence pair and its group.

    Args:
        request_id: Request identifier. It must parse as an integer because
            ``int(request_id)`` is used as the id of both sequences.
        decoder_prompt_length: Number of decoder prompt tokens.
        encoder_prompt_length: Number of encoder prompt tokens.
        block_size: Block size for both sequences. Any falsy value
            (``None`` or ``0``) falls back to ``decoder_prompt_length``.
        lora_request: Optional LoRA request attached to the sequence group.

    Returns:
        ``(decoder_sequence, encoder_sequence, sequence_group)``. The group
        lists only the decoder sequence in ``seqs`` and carries the encoder
        sequence as ``encoder_seq``.
    """
    if not block_size:
        block_size = decoder_prompt_length

    # Decoder tokens count up (0 .. decoder_prompt_length - 1); encoder
    # tokens count down (encoder_prompt_length - 1 .. 0). The prompt strings
    # are just the space-joined ids, not a tokenizer decoding of them.
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
    encoder_prompt_tokens = list(reversed(range(encoder_prompt_length)))
    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(decoder_prompt_tokens,
                                prompt=decoder_prompt_str),
        "encoder": token_inputs(encoder_prompt_tokens,
                                prompt=encoder_prompt_str),
    }

    decoder_prompt = Sequence(int(request_id),
                              inputs=inputs["decoder"],
                              block_size=block_size)

    encoder_prompt = Sequence(int(request_id),
                              inputs=inputs["encoder"],
                              block_size=block_size)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group


121
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a ``SequenceGroup`` whose sequences already have output tokens.

    One sequence is created per entry of ``seq_output_lens``. Every sequence
    has a prompt of ``seq_prompt_len`` zeros, a block size of 16, and output
    tokens ``0 .. output_len - 1``, each appended with a ``Logprob(0.0)``.

    Args:
        seq_prompt_len: Number of (all-zero) prompt tokens per sequence.
        seq_output_lens: Output length of each sequence; must be non-empty.
        request_id: Request id of the group.
        seq_id_start: Sequence id of the first sequence; later sequences
            get consecutive ids.
        sampling_params: Sampling parameters for the group; a default
            ``SamplingParams()`` is used when omitted.

    Returns:
        The populated ``SequenceGroup``.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs: list[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=token_inputs(prompt_token_ids),
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group


160
161
def create_seq_group_encoder_decoder(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build an encoder/decoder ``SequenceGroup`` with pre-filled outputs.

    One decoder sequence is created per entry of ``seq_output_lens``; each
    gets output tokens ``0 .. output_len - 1`` appended with a
    ``Logprob(0.0)``. A single encoder sequence is added as ``encoder_seq``.
    All sequences use an all-zero prompt of ``seq_prompt_len`` tokens and a
    block size of 16.

    Args:
        seq_prompt_len: Number of (all-zero) prompt tokens per sequence.
        seq_output_lens: Output length of each decoder sequence; must be
            non-empty.
        request_id: Request id of the group.
        seq_id_start: Sequence id of the first decoder sequence; later
            decoder sequences get consecutive ids.
        sampling_params: Sampling parameters for the group; a default
            ``SamplingParams()`` is used when omitted.

    Returns:
        The populated ``SequenceGroup``.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(prompt_token_ids),
        "encoder": token_inputs(prompt_token_ids),
    }

    seqs = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        # Construct decoder input sequences
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=inputs["decoder"],
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder input sequence; its id is the first one past the decoder ids.
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(seq_output_lens),
        inputs=inputs["encoder"],
        block_size=16,
    )

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)


209
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed for ``seq_len``.

    Despite the name, the result is a block *count* (ceiling division), not
    a length rounded up to a multiple of ``block_size``.

    Args:
        seq_len: Number of tokens to store.
        block_size: Number of tokens per block.

    Returns:
        ``ceil(seq_len / block_size)`` for positive ``block_size``.

    Raises:
        ZeroDivisionError: If ``block_size`` is 0.
    """
    return (seq_len + block_size - 1) // block_size


# Helper functions for scheduler tests


def get_sequence_groups(scheduler_output):
    """Collect the ``seq_group`` of each scheduled entry, in order.

    Args:
        scheduler_output: Object exposing ``scheduled_seq_groups``, an
            iterable of entries that each have a ``seq_group`` attribute.

    Returns:
        A new list with one ``seq_group`` per scheduled entry.
    """
    groups = []
    for scheduled in scheduler_output.scheduled_seq_groups:
        groups.append(scheduled.seq_group)
    return groups


def append_new_token(out, token_id: int):
    """Append ``token_id`` to every sequence of every scheduled group.

    Args:
        out: Scheduler output whose scheduled sequence groups are updated.
        token_id: Token id to append; it is also used as the value of the
            ``Logprob`` recorded for the token.
    """
    for group in get_sequence_groups(out):
        for sequence in group.get_seqs():
            # Same per-sequence append as the single-sequence helper.
            append_new_token_seq(sequence, token_id)


def schedule_and_update_computed_tokens(scheduler):
    """Run one scheduling step and mark the scheduled chunks as computed.

    Calls ``scheduler.schedule()`` and, for every scheduled entry, advances
    its sequence group's computed-token count by the entry's
    ``token_chunk_size``.

    Args:
        scheduler: Object whose ``schedule()`` returns a 3-tuple; the third
            element is ignored.

    Returns:
        ``(metas, out)``: the first two elements returned by ``schedule()``.
    """
    metas, out, _ = scheduler.schedule()
    for scheduled in out.scheduled_seq_groups:
        scheduled.seq_group.update_num_computed_tokens(
            scheduled.token_chunk_size)
    return metas, out


234
235
236
237
def append_new_token_seq(seq: Sequence, token_id: int):
    """Append ``token_id`` to ``seq`` with a matching logprob entry.

    Args:
        seq: Sequence that receives the new token.
        token_id: Token id to append; it is also used as the value of the
            ``Logprob`` recorded for the token.
    """
    token_logprobs = {token_id: Logprob(token_id)}
    seq.append_token_id(token_id, token_logprobs)


238
239
240
241
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    """Advance a group's computed tokens, then append one token per sequence.

    Args:
        token_chunk_size: Amount passed to
            ``seq_group.update_num_computed_tokens``.
        seq_group: Sequence group to update.
        token_id: Token id appended to every sequence of the group; it is
            also used as the value of the ``Logprob`` recorded for it.
    """
    # The computed-token update happens before any token is appended.
    seq_group.update_num_computed_tokens(token_chunk_size)
    for sequence in seq_group.get_seqs():
        append_new_token_seq(sequence, token_id)
242
243
244
245
246
247
248
249
250


class SchedulerProxy:
    """
    A proxy class to forward calls to the scheduler.

    Method calls made through the proxy are forwarded to the wrapped
    scheduler and recorded in ``call_history`` under the method's name as
    ``(args, kwargs, result)`` tuples. Non-callable attributes of the
    scheduler are returned unchanged and are not recorded.
    """

    def __init__(self, scheduler: "Scheduler"):
        self.scheduler_ = scheduler
        # Maps a method name to the list of recorded calls, oldest first.
        self.call_history: dict[str, list[Any]] = defaultdict(list)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. If one of the
        # proxy's own attributes is missing (bare instance, copy, unpickle),
        # reading self.scheduler_ below would re-enter this method forever,
        # so report the attribute as missing instead.
        if name in ("scheduler_", "call_history"):
            raise AttributeError(name)

        target = getattr(self.scheduler_, name)
        if not callable(target):
            # Plain data attribute: hand it back instead of wrapping it in
            # a function that would fail when called.
            return target

        def wrapper(*args, **kwargs):
            # Resolve again at call time so that a method replaced on the
            # scheduler after the wrapper was obtained is still honoured.
            result = getattr(self.scheduler_, name)(*args, **kwargs)
            self.call_history[name].append((args, kwargs, result))
            return result

        return wrapper

    def last_schedule_ret(
        self
    ) -> "tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]":
        """Return the result of the most recent ``schedule`` call.

        Raises:
            IndexError: If ``schedule`` has not been called via the proxy.
        """
        _, _, ret = self.call_history["schedule"][-1]
        return ret
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392


def create_seq_group_metadata_from_prompts(
    prompts: list[list[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_prompt_lens: list[int],
    continuations: Optional[list[list[int]]] = None,
    seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
    """Build one ``SequenceGroupMetadata`` per prompt/continuation pair.

    Args:
        prompts: Prompt token ids, one list per sequence.
        num_gpu_blocks: Size of the pool of block ids (``0 ..
            num_gpu_blocks - 1``) from which block tables are drawn.
        block_size: Tokens per block, used to size each block table.
        final_prompt_lens: Per-sequence length used to decide how many
            blocks to reserve.
        continuations: Already-generated tokens per sequence; defaults to an
            empty continuation for every prompt.
        seq_ids: Accepted, but not used for keying: request ids, sequence
            ids and block-table keys are the positional index of the prompt.

    Returns:
        Metadata entries in prompt order. An entry is flagged as a prompt
        when its continuation is empty.

    Raises:
        IndexError: If the block pool runs out while reserving blocks.
    """
    if continuations is None:
        continuations = [[] for _ in prompts]

    if seq_ids is None:
        seq_ids = [index for index, _ in enumerate(prompts)]

    # Block ids are taken from the end of the pool, highest id first.
    unused_blocks = list(range(num_gpu_blocks))
    blocks_by_index = {}
    for index, final_len in enumerate(final_prompt_lens):
        needed = round_up_to_next_block(final_len, block_size)
        blocks_by_index[index] = [unused_blocks.pop() for _ in range(needed)]

    metadata_list = []
    for index, pair in enumerate(zip(prompts, continuations)):
        prompt_ids, cont_ids = pair
        seq = SequenceData.from_seqs(prompt_ids, cont_ids)
        # All tokens except the last one are marked as already computed.
        seq.update_num_computed_tokens(len(prompt_ids) + len(cont_ids) - 1)
        metadata = SequenceGroupMetadata(
            request_id=str(index),
            is_prompt=len(cont_ids) == 0,
            seq_data={index: seq},
            sampling_params=SamplingParams(temperature=0.0),
            # Each entry gets its own copy of the reserved block list.
            block_tables={index: blocks_by_index[index][:]},
        )
        metadata_list.append(metadata)
    return metadata_list


def create_chunked_seq_group_metadata_from_prompt(
        prompt: list[int],
        num_gpu_blocks: int,
        chunk_size: int,
        block_size: int,
        seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
    """Split one prompt into prefill chunks, one metadata entry per chunk.

    Args:
        prompt: Prompt token ids.
        num_gpu_blocks: Size of the pool of block ids from which the
            prompt's block table is drawn.
        chunk_size: Maximum number of prompt tokens per chunk.
        block_size: Tokens per block, used to size the block table.
        seq_id: Used (stringified) as the request id of every chunk;
            defaults to 0.

    Returns:
        One ``SequenceGroupMetadata`` per chunk, in prompt order. Each entry
        holds the full prompt with the computed-token count set to the
        chunk's start offset; only the final chunk has ``do_sample`` set.
        ``seq_data`` and ``block_tables`` are keyed by the chunk index, and
        every entry refers to the same block-table list object.

    Raises:
        IndexError: If the block pool runs out while reserving blocks.
    """
    request_id = str(0 if seq_id is None else seq_id)

    # Block ids are taken from the end of the pool, highest id first.
    unused_blocks = list(range(num_gpu_blocks))
    prompt_len = len(prompt)
    num_blocks = round_up_to_next_block(prompt_len, block_size)
    block_table = [unused_blocks.pop() for _ in range(num_blocks)]

    chunks = []
    for chunk_index, start in enumerate(range(0, prompt_len, chunk_size)):
        end = start + chunk_size
        seq = SequenceData.from_seqs(prompt)
        seq.update_num_computed_tokens(start)
        chunks.append(
            SequenceGroupMetadata(
                request_id=request_id,
                is_prompt=True,
                do_sample=end >= prompt_len,  # terminal chunk
                seq_data={chunk_index: seq},
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={chunk_index: block_table},
                token_chunk_size=len(prompt[start:end])))
    return chunks


def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, list[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[list[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None,
                 prefill_chunk_size: Optional[int] = None):
    """Create a batch of sequence-group metadata with unique token ids.

    Token ids come from one shared counter: all prompts are filled first,
    then (in the non-chunked case) the previous-output tokens.

    Args:
        batch_size: Number of sequences. In the chunked case it caps the
            number of metadata entries returned.
        k: Extra tokens per sequence, added together with one more token
            when sizing block tables in the non-chunked case.
        prompt_len: One length for all prompts, or a per-prompt list.
        prev_output_token_len: Previous-output tokens per sequence
            (non-chunked case only).
        seq_ids: Sequence ids; in the chunked case a falsy value defaults to
            ``0 .. len(prompts) - 1``.
        num_gpu_blocks: Block pool size; defaults to ``2048 // block_size``.
        block_size: Tokens per block; defaults to 8.
        prefill_chunk_size: When truthy, prompts are split into prefill
            chunks of at most this size and no previous outputs are made.

    Returns:
        ``(seq_group_metadata_list, prompts, prev_output_tokens)``;
        ``prev_output_tokens`` is an empty list in the chunked case.
    """
    block_size = 8 if block_size is None else block_size
    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    token_counter = count()

    if isinstance(prompt_len, int):
        prompt_lens = [prompt_len] * batch_size
    else:
        prompt_lens = prompt_len

    prompts = [[next(token_counter) for _ in range(length)]
               for length in prompt_lens]

    if not prefill_chunk_size:
        prev_output_tokens = [[
            next(token_counter) for _ in range(prev_output_token_len)
        ] for _ in range(batch_size)]
        # Reserve room for prompt, previous outputs, k tokens and one more.
        final_prompt_lens = [
            len(prompt) + len(outputs) + k + 1
            for prompt, outputs in zip(prompts, prev_output_tokens)
        ]
        metadata_list = create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size, final_prompt_lens,
            prev_output_tokens, seq_ids)
        return metadata_list, prompts, prev_output_tokens

    # Create a batch of chunked prompts.
    if not seq_ids:
        seq_ids = list(range(len(prompts)))
    chunked_metadata = []
    for prompt, sid in zip(prompts, seq_ids):
        chunked_metadata.extend(
            create_chunked_seq_group_metadata_from_prompt(
                prompt, num_gpu_blocks, prefill_chunk_size, block_size, sid))
    return chunked_metadata[:batch_size], prompts, []