utils.py 8.36 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
7
from collections.abc import Sequence as GenericSequence
from typing import Any, Optional
8

9
10
import torch

11
from vllm import SamplingParams
12
from vllm.core.scheduler import Scheduler, SchedulerOutputs
13
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
14
from vllm.lora.request import LoRARequest
15
16
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                           SequenceGroupMetadata)
17
18
19


def create_dummy_prompt(
    request_id: str,
    prompt_length: int = -1,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    prompt_tokens: Optional[list[int]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    min_tokens: int = 0,
    max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]:
    """Build a single-sequence dummy prompt and its ``SequenceGroup``.

    Args:
        request_id: Request id of the group. It is also passed through
            ``int()`` and used as the sequence id, so it must be numeric.
        prompt_length: Number of dummy tokens ``0 .. prompt_length - 1`` to
            generate when ``prompt_tokens`` is not given.
        block_size: Block size of the ``Sequence``. When falsy (``None`` or
            ``0``) it falls back to ``prompt_length``.
        lora_request: Optional LoRA request attached to the group.
        prompt_tokens: Explicit prompt token ids; when given, no dummy tokens
            are generated.
        prompt_embeds: When given, the sequence is built from these
            embeddings via ``embeds_inputs`` and the token ids are unused.
        min_tokens: ``min_tokens`` of the group's ``SamplingParams``.
        max_tokens: ``max_tokens`` of the group's ``SamplingParams``.

    Returns:
        ``(sequence, seq_group)`` where ``seq_group`` holds only ``sequence``.
    """
    if not block_size:
        # NOTE(review): with the default prompt_length=-1 this yields a block
        # size of -1, so callers relying on the default should pass
        # block_size explicitly.
        block_size = prompt_length

    if prompt_tokens is None:
        # Dummy prompt with tokens 0 .. prompt_length - 1 and the matching
        # prompt string "0 1 ... prompt_length-1".
        prompt_tokens = list(range(prompt_length))

    if prompt_embeds is None:
        prompt_str = " ".join(str(t) for t in prompt_tokens)
        inputs = token_inputs(prompt_token_ids=prompt_tokens,
                              prompt=prompt_str)
    else:
        inputs = embeds_inputs(prompt_embeds=prompt_embeds)

    prompt = Sequence(
        int(request_id),
        inputs=inputs,
        block_size=block_size,
    )
    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=[prompt],
        arrival_time=time.time(),
        sampling_params=SamplingParams(max_tokens=max_tokens,
                                       min_tokens=min_tokens),
        lora_request=lora_request,
    )

    return prompt, seq_group


59
def create_dummy_lora_sequence(request_id: int, token_ids: list[int],
                               block_size: int, lora_int_id: int) -> Sequence:
    """Build a ``Sequence`` over ``token_ids`` with a dummy LoRA request.

    The LoRA request is always named ``"dummy"`` with path ``"/dummy"``;
    only ``lora_int_id`` is taken from the caller.

    Args:
        request_id: Used as the sequence id.
        token_ids: Prompt token ids of the sequence.
        block_size: Block size of the sequence.
        lora_int_id: Integer id of the dummy LoRA adapter.
    """
    return Sequence(seq_id=request_id,
                    inputs=token_inputs(token_ids),
                    block_size=block_size,
                    lora_request=LoRARequest(lora_name="dummy",
                                             lora_path="/dummy",
                                             lora_int_id=lora_int_id))


69
def create_dummy_sequence(request_id: int, token_ids: list[int],
                          block_size: int) -> Sequence:
    """Build a plain (non-LoRA) ``Sequence`` over ``token_ids``.

    Args:
        request_id: Used as the sequence id.
        token_ids: Prompt token ids of the sequence.
        block_size: Block size of the sequence.
    """
    return Sequence(
        seq_id=request_id,
        inputs=token_inputs(token_ids),
        block_size=block_size,
    )


78
79
80
81
82
83
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
) -> tuple[Sequence, Sequence, SequenceGroup]:
    """Build dummy decoder and encoder prompts and their ``SequenceGroup``.

    The decoder prompt holds tokens ``0 .. decoder_prompt_length - 1``. The
    encoder prompt holds ``encoder_prompt_length - 1 .. 0`` in descending
    order. Each prompt string is its token ids joined by spaces.

    Args:
        request_id: Request id of the group. ``int(request_id)`` is used as
            the id of both sequences, so it must be numeric.
        decoder_prompt_length: Number of decoder prompt tokens.
        encoder_prompt_length: Number of encoder prompt tokens.
        block_size: Block size of both sequences. When falsy (``None`` or
            ``0``) it falls back to ``decoder_prompt_length``.
        lora_request: Optional LoRA request attached to the group.

    Returns:
        ``(decoder_prompt, encoder_prompt, seq_group)``. ``seq_group`` holds
        the decoder sequence and has the encoder sequence as ``encoder_seq``.
    """
    if not block_size:
        block_size = decoder_prompt_length

    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join(str(t) for t in decoder_prompt_tokens)
    # Descending order, presumably so encoder tokens are distinguishable
    # from decoder tokens in tests.
    encoder_prompt_tokens = list(reversed(range(encoder_prompt_length)))
    encoder_prompt_str = " ".join(str(t) for t in encoder_prompt_tokens)

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(decoder_prompt_tokens,
                                prompt=decoder_prompt_str),
        "encoder": token_inputs(encoder_prompt_tokens,
                                prompt=encoder_prompt_str),
    }

    decoder_prompt = Sequence(int(request_id),
                              inputs=inputs["decoder"],
                              block_size=block_size)

    encoder_prompt = Sequence(int(request_id),
                              inputs=inputs["encoder"],
                              block_size=block_size)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group


120
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a ``SequenceGroup`` holding one sequence per output length.

    Every sequence starts from the same prompt of ``seq_prompt_len`` zero
    tokens (block size 16). It is then extended with ``output_len`` output
    tokens ``0 .. output_len - 1``, each recorded with ``Logprob(0.0)``.

    Args:
        seq_prompt_len: Number of prompt tokens (all ``0``).
        seq_output_lens: Output length of each sequence; must be non-empty.
        request_id: Request id of the group.
        seq_id_start: Id of the first sequence; later ones count up by one.
        sampling_params: Sampling params of the group; a default
            ``SamplingParams()`` is used when ``None``.

    Raises:
        AssertionError: If ``seq_output_lens`` is empty.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs: list[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=token_inputs(prompt_token_ids),
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group


159
160
def create_seq_group_encoder_decoder(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build an encoder/decoder ``SequenceGroup`` with one decoder sequence
    per output length.

    The decoder and encoder prompts are both ``seq_prompt_len`` zero tokens.
    Each decoder sequence (block size 16) is extended with ``output_len``
    output tokens ``0 .. output_len - 1``, each recorded with
    ``Logprob(0.0)``. Decoder ids run from ``seq_id_start``. The encoder
    sequence takes the next id, ``seq_id_start + len(seq_output_lens)``.

    Args:
        seq_prompt_len: Number of prompt tokens (all ``0``).
        seq_output_lens: Output length of each decoder sequence; must be
            non-empty.
        request_id: Request id of the group.
        seq_id_start: Id of the first decoder sequence.
        sampling_params: Sampling params of the group; a default
            ``SamplingParams()`` is used when ``None``.

    Raises:
        AssertionError: If ``seq_output_lens`` is empty.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(prompt_token_ids),
        "encoder": token_inputs(prompt_token_ids),
    }

    seqs: list[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        # Construct decoder input sequences
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=inputs["decoder"],
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder input sequence, with the id following the last decoder seq.
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(seq_output_lens),
        inputs=inputs["encoder"],
        block_size=16,
    )

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)


208
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return the number of ``block_size`` blocks needed to hold ``seq_len``.

    Despite the name, this returns a block *count*, i.e.
    ``ceil(seq_len / block_size)`` for positive ``block_size``. It does not
    return ``seq_len`` rounded up to a multiple of ``block_size``.

    Raises:
        ZeroDivisionError: If ``block_size`` is 0.
    """
    return (seq_len + block_size - 1) // block_size


# Helper functions for scheduler tests


def get_sequence_groups(scheduler_output):
    """Collect the ``seq_group`` of every scheduled entry, in order.

    ``scheduler_output`` is expected to expose ``scheduled_seq_groups``,
    an iterable of entries that each carry a ``seq_group`` attribute.
    """
    groups = []
    for scheduled in scheduler_output.scheduled_seq_groups:
        groups.append(scheduled.seq_group)
    return groups


def append_new_token(out, token_id: int):
    """Append ``token_id`` to every sequence of every scheduled group.

    Each appended token gets its own ``{token_id: Logprob(token_id)}``
    logprobs dict.
    """
    for group in get_sequence_groups(out):
        for sequence in group.get_seqs():
            sequence.append_token_id(token_id, {token_id: Logprob(token_id)})


def schedule_and_update_computed_tokens(scheduler):
    """Run one ``scheduler.schedule()`` step and mark scheduled chunks computed.

    For every scheduled group, ``update_num_computed_tokens`` is called with
    that group's ``token_chunk_size``. This presumably simulates the model
    having processed the scheduled chunk.

    Returns:
        ``(metas, out)``: the first two items of ``schedule()``'s 3-tuple
        result. The third item is discarded.
    """
    metas, out, _ = scheduler.schedule()
    for s in out.scheduled_seq_groups:
        s.seq_group.update_num_computed_tokens(s.token_chunk_size)
    return metas, out


233
234
235
236
def append_new_token_seq(seq: Sequence, token_id: int):
    """Append ``token_id`` to ``seq`` with ``Logprob(token_id)`` as logprob."""
    logprobs = {token_id: Logprob(token_id)}
    seq.append_token_id(token_id, logprobs)


237
238
239
240
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    """Mark ``token_chunk_size`` tokens computed, then append ``token_id``.

    The computed-token update on ``seq_group`` happens first. Then every
    sequence in the group gets ``token_id`` with its own
    ``{token_id: Logprob(token_id)}`` logprobs dict.
    """
    seq_group.update_num_computed_tokens(token_chunk_size)
    members = seq_group.get_seqs()
    for member in members:
        member.append_token_id(token_id, {token_id: Logprob(token_id)})
241
242
243
244
245
246
247
248
249


class SchedulerProxy:
    """
    A proxy class to forward calls to the scheduler.

    Every call made through a callable attribute of the wrapped scheduler is
    recorded in ``call_history[name]`` as an ``(args, kwargs, result)``
    tuple, in call order. Non-callable attributes are returned unchanged.
    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler_ = scheduler
        # Method name -> list of (args, kwargs, result), in call order.
        self.call_history: dict[str, list[Any]] = defaultdict(list)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs for names not found normally. If the proxy's
        # own attributes are missing (e.g. copy/pickle skip __init__), raise
        # instead of recursing through self.scheduler_ forever.
        if name in ("scheduler_", "call_history"):
            raise AttributeError(name)

        attr = getattr(self.scheduler_, name)
        if not callable(attr):
            # Plain data attributes are forwarded as-is; wrapping them in a
            # function would hide their value from callers.
            return attr

        def wrapper(*args, **kwargs):
            # Look the method up again at call time so later reassignment of
            # the scheduler's attribute is still honoured.
            result = getattr(self.scheduler_, name)(*args, **kwargs)
            self.call_history[name].append((args, kwargs, result))
            return result

        return wrapper

    def last_schedule_ret(
        self
    ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]:
        """Return the result of the most recent ``schedule()`` call.

        Raises:
            IndexError: If ``schedule()`` has not been called through this
                proxy yet.
        """
        history = self.call_history["schedule"]
        if not history:
            raise IndexError(
                "schedule() has not been called through this proxy")
        _, _, ret = history[-1]
        return ret