# utils.py — helpers that build dummy Sequence / SequenceGroup objects.
import time
from typing import Optional, Tuple

from vllm import SamplingParams
from vllm.sequence import Logprob, Sequence, SequenceGroup


def create_dummy_prompt(
        request_id: str,
        prompt_length: int,
        block_size: Optional[int] = None) -> Tuple[Sequence, SequenceGroup]:
    """Build a one-sequence prompt and wrap it in a ``SequenceGroup``.

    The prompt's token ids are ``0, 1, ..., prompt_length - 1``. Its text is
    those ids joined by single spaces, e.g. ``"0 1 2"`` for length 3.

    Args:
        request_id: Id of the request. ``int(request_id)`` is also passed to
            ``Sequence`` as its first argument (presumably the sequence id),
            so this must be an integer string.
        prompt_length: Number of prompt tokens to generate.
        block_size: Block size passed to ``Sequence``. Any falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.

    Returns:
        ``(prompt, seq_group)``: the created ``Sequence`` and the
        ``SequenceGroup`` containing it as its only sequence.
    """
    # Falsy check (not ``is None``) kept deliberately: a block_size of 0 has
    # always been treated the same as "not given".
    if not block_size:
        block_size = prompt_length

    # Token ids 0 .. prompt_length-1, prompt text "0 1 ... prompt_length-1".
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join(str(t) for t in prompt_tokens)
    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
    seq_group = SequenceGroup(request_id, [prompt], SamplingParams(),
                              time.time(), None)

    return prompt, seq_group


def create_seq_group(
    seq_prompt_len=1024,
    seq_output_lens=(128, ),
    request_id='0',
    seq_id_start=0,
    block_size=16,
) -> SequenceGroup:
    """Build a ``SequenceGroup`` whose sequences share one all-zero prompt.

    One ``Sequence`` is created per entry of ``seq_output_lens``. The k-th
    sequence gets id ``seq_id_start + k``. It is then extended with output
    tokens ``0, 1, ..., seq_output_lens[k] - 1``, each appended with a
    ``Logprob(0.0)`` entry.

    Args:
        seq_prompt_len: Number of prompt tokens; every prompt token id is 0.
        seq_output_lens: Output length of each sequence; must be non-empty.
        request_id: Request id of the resulting group.
        seq_id_start: Id assigned to the first sequence; later sequences
            get consecutive ids.
        block_size: Block size passed to every ``Sequence``. Defaults to 16,
            the value that was previously hard-coded.

    Returns:
        The newly created ``SequenceGroup``.

    Raises:
        AssertionError: If ``seq_output_lens`` is empty.
    """
    assert len(seq_output_lens) > 0

    # NOTE(review): every sequence receives this same list object; confirm
    # Sequence never mutates its prompt token ids in place.
    prompt_token_ids = [0] * seq_prompt_len

    seqs = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            prompt="",
            prompt_token_ids=prompt_token_ids,
            block_size=block_size,
        )

        # Fake generation: output token i is appended with a single
        # logprob entry {i: Logprob(0.0)}.
        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=SamplingParams(),
        arrival_time=time.time(),
    )

    return seq_group


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of ``block_size`` are needed to hold ``seq_len``.

    Despite the name, the result is a block *count* (ceiling of
    ``seq_len / block_size``), not ``seq_len`` rounded up to a multiple of
    ``block_size``. For example, 17 tokens with a block size of 16 need 2
    blocks.
    """
    # floor((seq_len - 1) / block_size) + 1 is the integer ceiling division,
    # identical to (seq_len + block_size - 1) // block_size.
    return (seq_len - 1) // block_size + 1