utils.py 5.98 KB
Newer Older
1
import time
2
3
4
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
5
6

from vllm import SamplingParams
7
from vllm.lora.request import LoRARequest
8
from vllm.sequence import Logprob, Sequence, SequenceGroup
9
10
11


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
    """Build a one-sequence group around a synthetic prompt.

    Args:
        request_id: Identifier given to the ``SequenceGroup``. It is also
            passed through ``int()`` to obtain the sequence id, so it must
            be a numeric string.
        prompt_length: Number of prompt tokens to generate.
        block_size: Block size handed to ``Sequence``. A falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Optional LoRA request attached to the group.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        A ``(prompt, seq_group)`` tuple: the prompt sequence and the
        sequence group that wraps it.
    """
    if not block_size:
        block_size = prompt_length

    # Create dummy prompt sequence with tokens 0...prompt_length-1
    # and the matching prompt string "0 1 ... prompt_length-1".
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join(str(t) for t in prompt_tokens)

    prompt = Sequence(int(request_id),
                      inputs={
                          "prompt": prompt_str,
                          "prompt_token_ids": prompt_tokens,
                      },
                      block_size=block_size)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[prompt],
                              arrival_time=time.time(),
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              lora_request=lora_request)

    return prompt, seq_group


43
44
45
46
47
48
49
50
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    """Build a sequence group with synthetic decoder and encoder prompts.

    Args:
        request_id: Identifier given to the ``SequenceGroup``. It is also
            passed through ``int()`` to obtain the sequence id used for
            BOTH the decoder and the encoder sequence, so it must be a
            numeric string.
        decoder_prompt_length: Number of decoder prompt tokens.
        encoder_prompt_length: Number of encoder prompt tokens.
        block_size: Block size handed to both ``Sequence`` objects. A falsy
            value (``None`` or ``0``) falls back to
            ``decoder_prompt_length``.
        lora_request: Optional LoRA request attached to the group.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        A ``(decoder_prompt, encoder_prompt, seq_group)`` tuple.
    """
    if not block_size:
        block_size = decoder_prompt_length

    # Decoder prompt: tokens 0...decoder_prompt_length-1 and the matching
    # space-separated prompt string.
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join(str(t) for t in decoder_prompt_tokens)

    decoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": decoder_prompt_str,
                                  "prompt_token_ids": decoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)

    # Encoder prompt: tokens encoder_prompt_length-1...0 (descending), so
    # its content differs from the decoder prompt.
    encoder_prompt_tokens = list(reversed(range(encoder_prompt_length)))
    encoder_prompt_str = " ".join(str(t) for t in encoder_prompt_tokens)
    encoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": encoder_prompt_str,
                                  "prompt_token_ids": encoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group


89
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a sequence group whose sequences already hold output tokens.

    One ``Sequence`` is created per entry of ``seq_output_lens``. Each one
    gets a prompt of ``seq_prompt_len`` zero tokens and then has token ids
    ``0...output_len-1`` appended, each with a ``Logprob(0.0)``.

    Args:
        seq_prompt_len: Number of (all-zero) prompt tokens per sequence.
        seq_output_lens: Output length for each sequence; must be non-empty.
        request_id: Identifier given to the ``SequenceGroup``.
        seq_id_start: Id of the first sequence; later sequences count up
            from it.
        sampling_params: Sampling parameters for the group. ``None`` means a
            default-constructed ``SamplingParams``.

    Returns:
        The populated ``SequenceGroup``.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    # The same list object is passed to every sequence's inputs.
    prompt_token_ids = [0] * seq_prompt_len

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={"prompt_token_ids": prompt_token_ids},
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group


128
129
def create_seq_group_encoder_decoder(
        seq_prompt_len: int = 1024,
        seq_output_lens: GenericSequence[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a sequence group with decoder sequences and an encoder sequence.

    One decoder ``Sequence`` is created per entry of ``seq_output_lens``.
    Each one gets a prompt of ``seq_prompt_len`` zero tokens and then has
    token ids ``0...output_len-1`` appended, each with a ``Logprob(0.0)``.
    A single encoder ``Sequence`` with the same prompt tokens and no output
    tokens is attached via ``encoder_seq``.

    Args:
        seq_prompt_len: Number of (all-zero) prompt tokens per sequence.
        seq_output_lens: Output length for each decoder sequence; must be
            non-empty.
        request_id: Identifier given to the ``SequenceGroup``.
        seq_id_start: Id of the first decoder sequence; later decoder
            sequences count up from it.
        sampling_params: Sampling parameters for the group. ``None`` means a
            default-constructed ``SamplingParams``.

    Returns:
        The populated ``SequenceGroup``.
    """
    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    # The same list object is passed to every sequence's inputs.
    prompt_token_ids = [0] * seq_prompt_len

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={
                "prompt": "",
                "prompt_token_ids": prompt_token_ids,
                "multi_modal_data": None,
            },
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder sequence: takes the first id after the decoder sequences.
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(seq_output_lens),
        inputs={
            "prompt": "",
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": None,
        },
        block_size=16,
    )

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)


179
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed for ``seq_len``.

    Despite the name, the result is a block COUNT (ceiling division for a
    positive ``block_size``), not ``seq_len`` rounded up to a multiple of
    ``block_size``.

    Raises:
        ZeroDivisionError: If ``block_size`` is 0.
    """
    return (seq_len + block_size - 1) // block_size