# test_sequence.py
import time
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput)


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> SequenceGroup:
    """Build a ``SequenceGroup`` holding one synthetic prompt sequence.

    The prompt consists of the token ids ``0 .. prompt_length - 1`` and the
    text formed by joining those ids with single spaces
    (``"0 1 ... prompt_length-1"``).

    Args:
        request_id: Id of the request. It must parse as an ``int`` because
            ``int(request_id)`` is also used as the prompt sequence's id.
        prompt_length: Number of prompt tokens to create.
        block_size: Block size passed to ``Sequence``. Any falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Optional LoRA request attached to the group.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        A ``SequenceGroup`` containing the single prompt sequence, with its
        arrival time set to the current ``time.time()``.
    """
    # Falsy block sizes (None, 0) default to the prompt length.
    if not block_size:
        block_size = prompt_length

    # Tokens are 0..prompt_length-1 (not block_size); the prompt text is
    # those same ids separated by spaces.
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join(str(t) for t in prompt_tokens)
    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
    sampling_params = SamplingParams(use_beam_search=use_beam_search,
                                     best_of=best_of)
    return SequenceGroup(request_id, [prompt], sampling_params, time.time(),
                         lora_request)


@pytest.fixture
def sample_outputs():
    """Return five single-sample ``SequenceGroupOutput`` objects.

    Group ``i`` holds one ``SequenceOutput`` with parent sequence 0, output
    token ``i``, an empty logprobs dict, and no prompt logprobs.
    """
    outputs = []
    for token_id in range(5):
        sample = SequenceOutput(parent_seq_id=0,
                                output_token=token_id,
                                logprobs={})
        outputs.append(
            SequenceGroupOutput(samples=[sample], prompt_logprobs=None))
    return outputs


@pytest.fixture
def sampler_output(sample_outputs):
    """Wrap the ``sample_outputs`` fixture list in a ``SamplerOutput``."""
    wrapped = SamplerOutput(outputs=sample_outputs)
    return wrapped


def test_sampler_output_initialization(sampler_output, sample_outputs):
    """A SamplerOutput built only from ``outputs`` wraps all of them and
    leaves its optional token/metrics fields set to ``None``."""
    assert len(sampler_output) == len(sample_outputs)
    unset_fields = (
        "sampled_token_probs",
        "sampled_token_ids",
        "spec_decode_worker_metrics",
    )
    for field_name in unset_fields:
        assert getattr(sampler_output, field_name) is None


def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing a SamplerOutput yields the wrapped group output."""
    index = 2
    expected = sample_outputs[index]
    assert sampler_output[index] == expected


def test_sampler_output_setitem(sampler_output):
    """Assigning by index replaces the stored group output."""
    replacement_sample = SequenceOutput(parent_seq_id=0,
                                        output_token=99,
                                        logprobs={})
    replacement = SequenceGroupOutput(samples=[replacement_sample],
                                      prompt_logprobs=None)
    slot = 2
    sampler_output[slot] = replacement
    assert sampler_output[slot] == replacement


def test_sampler_output_len(sampler_output, sample_outputs):
    """``len()`` of a SamplerOutput equals the number of wrapped outputs."""
    expected_len = len(sample_outputs)
    assert len(sampler_output) == expected_len


def test_sampler_output_eq(sample_outputs):
    """SamplerOutputs over equal output lists compare equal; dropping the
    last output makes them compare unequal."""
    reference = SamplerOutput(outputs=sample_outputs)
    # Shallow copy: a distinct list holding the same element objects.
    duplicate = SamplerOutput(outputs=sample_outputs.copy())
    truncated = SamplerOutput(outputs=sample_outputs[:-1])
    assert reference == duplicate
    assert reference != truncated


def test_sequence_data_prefill():
    """Computed/uncomputed token counts follow chunked prefill and reset.

    A 4-token prompt is advanced by 2 and then by 1 computed tokens. One
    more token is then appended and the state is reset, after which all
    5 tokens must count as uncomputed again.
    """
    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
    assert seq_data.get_num_uncomputed_tokens() == 4
    assert seq_data.get_num_computed_tokens() == 0

    # advance by 2
    seq_data.update_num_computed_tokens(2)
    assert seq_data.get_num_uncomputed_tokens() == 2
    assert seq_data.get_num_computed_tokens() == 2

    # advance by 1
    seq_data.update_num_computed_tokens(1)
    assert seq_data.get_num_uncomputed_tokens() == 1
    assert seq_data.get_num_computed_tokens() == 3

    # append one token and reset, simulating recompute
    seq_data.append_token_id(1, logprob=0.0)
    seq_data.reset_state_for_recompute()
    # 4 prompt tokens + 1 appended token are all uncomputed after the reset.
    assert seq_data.get_num_uncomputed_tokens() == 5
    assert seq_data.get_num_computed_tokens() == 0


def test_sequence_group_stage():
    """Check ``is_prefill()`` through chunked prefill, then recompute.

    The 12-token prompt is computed in chunks of 6, 5 and 1 tokens and only
    leaves prefill on the last chunk. After one token is appended to the
    sequence data and the state is reset, the group stays in prefill through
    chunks of 5 and 7 tokens and leaves it after a final chunk of 1.
    """
    seq_group = create_dummy_prompt("1", 12)

    def run_chunks(chunks):
        # The group starts each phase in prefill; every chunk except the
        # last keeps it there, and the last one completes the prefill.
        assert seq_group.is_prefill() is True
        last = len(chunks) - 1
        for position, chunk in enumerate(chunks):
            seq_group.update_num_computed_tokens(chunk)
            assert seq_group.is_prefill() is (position < last)

    run_chunks([6, 5, 1])

    only_seqs = seq_group.get_seqs()
    assert len(only_seqs) == 1
    only_seqs[0].data.append_token_id(1, logprob=0.0)
    for member in seq_group.get_seqs():
        member.reset_state_for_recompute()

    run_chunks([5, 7, 1])