test_sequence.py 3.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest

6
from vllm.model_executor.layers.sampler import SamplerOutput
7
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
8
                           SequenceOutput)
9

10
11
from .core.utils import create_dummy_prompt

12
13
14
15

@pytest.fixture
def sample_outputs():
    """Build five completion group outputs, one per token id in ``range(5)``.

    Each entry wraps a single ``SequenceOutput`` with ``parent_seq_id=0``,
    ``output_token=i`` and an empty ``logprobs`` dict, and carries
    ``prompt_logprobs=None``.
    """
    return [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
            ],
            prompt_logprobs=None,
        ) for i in range(5)
    ]


@pytest.fixture
def sampler_output(sample_outputs):
    """Wrap the ``sample_outputs`` fixture list in a ``SamplerOutput``."""
    wrapped = SamplerOutput(outputs=sample_outputs)
    return wrapped


def test_sampler_output_initialization(sampler_output, sample_outputs):
    """Check the state of a ``SamplerOutput`` built from ``outputs`` alone.

    Its length must match the number of group outputs passed in, and both
    ``sampled_token_probs`` and ``sampled_token_ids`` must be ``None``.
    """
    actual_len = len(sampler_output)
    expected_len = len(sample_outputs)
    assert actual_len == expected_len

    # Neither optional field was supplied by the fixture.
    probs = sampler_output.sampled_token_probs
    assert probs is None
    token_ids = sampler_output.sampled_token_ids
    assert token_ids is None


def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing the sampler output yields the same item as the source list."""
    index = 2
    fetched = sampler_output[index]
    expected = sample_outputs[index]
    assert fetched == expected


def test_sampler_output_setitem(sampler_output):
    """Assigning to an index replaces the group output stored there."""
    # Token id 99 is outside the range(5) ids used by the fixture, so the
    # replacement cannot compare equal to the entry it overwrites.
    new_output = CompletionSequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
        ],
        prompt_logprobs=None,
    )
    sampler_output[2] = new_output
    assert sampler_output[2] == new_output


def test_sampler_output_len(sampler_output, sample_outputs):
    """``len()`` of the sampler output equals the number of wrapped outputs."""
    actual = len(sampler_output)
    expected = len(sample_outputs)
    assert actual == expected


def test_sampler_output_eq(sample_outputs):
    """Equality follows the wrapped outputs, not the identity of the list."""
    baseline = SamplerOutput(outputs=sample_outputs)

    # A shallow copy holds the same items in a distinct list object.
    same_items = SamplerOutput(outputs=sample_outputs.copy())
    assert baseline == same_items

    # Dropping the last item must make the two compare unequal.
    one_fewer = SamplerOutput(outputs=sample_outputs[:-1])
    assert baseline != one_fewer
57
58
59


def test_sequence_data_prefill():
    """Track computed/uncomputed token counts on a ``SequenceData``.

    Starts from four tokens, advances the computed count in two steps,
    then appends one token and resets for recompute, checking both
    counters after every step.
    """
    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
    assert seq_data.get_num_uncomputed_tokens() == 4
    assert seq_data.get_num_computed_tokens() == 0

    # advance by 2
    seq_data.update_num_computed_tokens(2)
    assert seq_data.get_num_uncomputed_tokens() == 2
    assert seq_data.get_num_computed_tokens() == 2

    # advance by 1
    seq_data.update_num_computed_tokens(1)
    assert seq_data.get_num_uncomputed_tokens() == 1
    assert seq_data.get_num_computed_tokens() == 3

    # append tokens and reset, simulating recompute
    seq_data.append_token_id(1, logprob=0.0)
    seq_data.reset_state_for_recompute()
    # The four original tokens plus the appended one are all uncomputed.
    assert seq_data.get_num_uncomputed_tokens() == 5
    assert seq_data.get_num_computed_tokens() == 0
78
79
80


def test_sequence_group_stage():
    """Check ``is_prefill()`` as computed tokens accumulate on a group.

    The group reports prefill until 12 tokens are marked computed. After
    one token is appended and every sequence is reset for recompute, it
    reports prefill again until 13 tokens are marked computed.
    """
    # NOTE(review): 12 is presumably the prompt length; confirm against
    # create_dummy_prompt in .core.utils.
    _, seq_group = create_dummy_prompt("1", 12)
    assert seq_group.is_prefill() is True

    # 6 + 5 = 11 computed: still prefill.
    seq_group.update_num_computed_tokens(6)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(5)
    assert seq_group.is_prefill() is True
    # 11 + 1 = 12 computed: prefill ends.
    seq_group.update_num_computed_tokens(1)
    assert seq_group.is_prefill() is False

    seqs = seq_group.get_seqs()
    assert len(seqs) == 1
    # Append one token, then reset every sequence to simulate recompute.
    seqs[0].data.append_token_id(1, logprob=0.0)
    for seq in seq_group.get_seqs():
        seq.reset_state_for_recompute()
    assert seq_group.is_prefill() is True

    # 5 + 7 = 12 computed: still prefill after the reset.
    seq_group.update_num_computed_tokens(5)
    assert seq_group.is_prefill() is True
    seq_group.update_num_computed_tokens(7)
    assert seq_group.is_prefill() is True
    # 12 + 1 = 13 computed: prefill ends.
    seq_group.update_num_computed_tokens(1)
    assert seq_group.is_prefill() is False