"vscode:/vscode.git/clone" did not exist on "215bf150a61ed3aafba35df65c95f63c82cc94b2"
utils.py 5.88 KB
Newer Older
1
import time
2
from typing import Iterable, Optional, Tuple
3
4

from vllm import SamplingParams
5
from vllm.lora.request import LoRARequest
6
from vllm.sequence import Logprob, Sequence, SequenceGroup
7
8
9


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
    """Build a dummy prompt sequence and a sequence group wrapping it.

    Args:
        request_id: Request identifier. It is passed unchanged to the
            ``SequenceGroup`` and converted with ``int()`` to obtain the
            sequence id, so it must be a numeric string.
        prompt_length: Number of prompt tokens to generate.
        block_size: Block size given to the ``Sequence``. A falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Optional LoRA request forwarded to the group.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        A ``(prompt, seq_group)`` tuple, where ``seq_group`` contains
        ``prompt`` as its only sequence.
    """
    if not block_size:
        block_size = prompt_length

    # Create dummy prompt sequence with tokens 0...prompt_length-1
    # and prompt text "0 1 ... prompt_length-1".
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join([str(t) for t in prompt_tokens])
    prompt = Sequence(int(request_id),
                      inputs={
                          "prompt": prompt_str,
                          "prompt_token_ids": prompt_tokens,
                      },
                      block_size=block_size)
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[prompt],
                              arrival_time=time.time(),
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              lora_request=lora_request)

    return prompt, seq_group


41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    """Build dummy decoder and encoder prompts plus their sequence group.

    Args:
        request_id: Request identifier. It is passed unchanged to the
            ``SequenceGroup`` and converted with ``int()`` to obtain the
            sequence id used for BOTH the decoder and encoder sequences,
            so it must be a numeric string.
        decoder_prompt_length: Number of decoder prompt tokens.
        encoder_prompt_length: Number of encoder prompt tokens.
        block_size: Block size given to both sequences. A falsy value
            (``None`` or ``0``) falls back to ``decoder_prompt_length``.
        lora_request: Optional LoRA request forwarded to the group.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        A ``(decoder_prompt, encoder_prompt, seq_group)`` tuple, where
        ``seq_group`` holds ``decoder_prompt`` as its only sequence and
        ``encoder_prompt`` as its ``encoder_seq``.
    """
    if not block_size:
        block_size = decoder_prompt_length

    # Decoder prompt: tokens 0...decoder_prompt_length-1 in ascending order,
    # with the prompt text being the space-joined token ids.
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])

    decoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": decoder_prompt_str,
                                  "prompt_token_ids": decoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)

    # Encoder prompt: tokens encoder_prompt_length-1...0 in descending order,
    # so its content differs from the decoder prompt.
    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
    encoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": encoder_prompt_str,
                                  "prompt_token_ids": encoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(
                                  use_beam_search=use_beam_search,
                                  best_of=best_of),
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group


87
def create_seq_group(
        seq_prompt_len: int = 1024,
        seq_output_lens: Iterable[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build a sequence group whose sequences already hold output tokens.

    One ``Sequence`` is created per entry of ``seq_output_lens``; each gets
    a prompt of ``seq_prompt_len`` zero tokens and then has ``output_len``
    tokens appended (token ids ``0..output_len-1``, each with a
    ``Logprob(0.0)``).

    Args:
        seq_prompt_len: Number of prompt tokens (all ``0``) per sequence.
        seq_output_lens: Output length for each sequence; must be non-empty.
            Any iterable is accepted.
        request_id: Request identifier for the group.
        seq_id_start: Sequence id of the first sequence; subsequent
            sequences use consecutive ids.
        sampling_params: Sampling parameters for the group. Defaults to
            ``SamplingParams()`` when ``None``.

    Returns:
        The populated ``SequenceGroup``.
    """
    # Materialize once: the parameter is typed as Iterable, which does not
    # guarantee len() and may only be consumable a single time.
    output_lens = list(seq_output_lens)
    assert len(output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs = []
    for seq_id_offset, output_len in enumerate(output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={"prompt_token_ids": prompt_token_ids},
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def create_seq_group_encoder_decoder(
        seq_prompt_len: int = 1024,
        seq_output_lens: Iterable[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
    """Build an encoder/decoder sequence group with pre-filled outputs.

    One decoder ``Sequence`` is created per entry of ``seq_output_lens``;
    each gets a prompt of ``seq_prompt_len`` zero tokens and then has
    ``output_len`` tokens appended (token ids ``0..output_len-1``, each with
    a ``Logprob(0.0)``). A single encoder ``Sequence`` with the same prompt
    tokens and no appended outputs is attached as ``encoder_seq``.

    Args:
        seq_prompt_len: Number of prompt tokens (all ``0``) per sequence.
        seq_output_lens: Output length for each decoder sequence; must be
            non-empty. Any iterable is accepted.
        request_id: Request identifier for the group.
        seq_id_start: Sequence id of the first decoder sequence; subsequent
            decoder sequences use consecutive ids and the encoder sequence
            takes the next id after the last decoder sequence.
        sampling_params: Sampling parameters for the group. Defaults to
            ``SamplingParams()`` when ``None``.

    Returns:
        The populated ``SequenceGroup``.
    """
    # Materialize once: the parameter is typed as Iterable, which does not
    # guarantee len() and may only be consumable a single time; the length
    # is needed both for validation and for the encoder sequence id.
    output_lens = list(seq_output_lens)
    assert len(output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs = []
    for seq_id_offset, output_len in enumerate(output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={
                "prompt": "",
                "prompt_token_ids": prompt_token_ids,
                "multi_modal_data": None,
            },
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder sequence: uses the first id not taken by a decoder sequence.
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(output_lens),
        inputs={
            "prompt": "",
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": None,
        },
        block_size=16,
    )

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)


177
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed for ``seq_len``.

    Despite the name, the result is a block COUNT (ceiling division for a
    positive ``block_size``), not ``seq_len`` rounded up to a multiple of
    ``block_size``.

    Raises:
        ZeroDivisionError: If ``block_size`` is ``0``.
    """
    # Adding block_size - 1 before floor division rounds any partial
    # block up to a whole one.
    return (seq_len + block_size - 1) // block_size