# test_utils.py (3.88 KB)
1
from unittest.mock import MagicMock
2
3

import pytest
4
import torch
5

6
7
8
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
9
10
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len


def test_get_all_seq_ids():
    """get_all_seq_ids should report every sequence id, in input order.

    One SequenceGroupMetadata is built per id (each holding exactly that one
    sequence), and the extracted ids are compared with the ids used to build
    them.
    """
    seq_ids = [*range(10), *range(100, 110)]

    def make_group(seq_id):
        # Only the seq id matters here; everything else is a mock placeholder.
        return SequenceGroupMetadata(
            request_id=str(seq_id),
            is_prompt=True,
            seq_data={seq_id: MagicMock()},
            sampling_params=MagicMock(),
            block_tables={seq_id: MagicMock()},
            lora_request=None,
        )

    groups = [make_group(seq_id) for seq_id in seq_ids]

    assert get_all_seq_ids(groups) == seq_ids


@pytest.fixture
def fake_sequence_group_metadata():
    """Provide three SequenceGroupMetadata objects with seq ids 0, 1 and 2.

    Each group is a prompt (``is_prompt=True``) containing a single sequence
    whose id is also used (as a string) for its request id. All remaining
    fields are MagicMock placeholders.
    """

    def build(idx):
        return SequenceGroupMetadata(
            request_id=str(idx),
            is_prompt=True,
            seq_data={idx: MagicMock()},
            sampling_params=MagicMock(),
            block_tables={idx: MagicMock()},
            lora_request=None,
        )

    return [build(idx) for idx in range(3)]


def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    """Groups with a zero-length proposal land in the second partition.

    split_batch_by_proposal_len returns two (groups, indices) pairs; this
    test inspects the second one and expects the groups at batch positions
    0 and 2 (the ones with proposal length 0), along with those positions.
    """
    proposal_lens = [0, 1, 0]

    _, (filtered_groups, indices) = split_batch_by_proposal_len(
        fake_sequence_group_metadata, proposal_lens)

    expected_groups = [
        fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
    ]
    expected_indices = [0, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    """Groups with a non-zero proposal length land in the first partition.

    With proposal lengths [0, 1, 2], the first (groups, indices) pair from
    split_batch_by_proposal_len should contain the groups at batch
    positions 1 and 2, along with those positions.
    """
    proposal_lens = [0, 1, 2]

    (filtered_groups, indices), _ = split_batch_by_proposal_len(
        fake_sequence_group_metadata, proposal_lens)

    expected_groups = [
        fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
    ]
    expected_indices = [1, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_empty_inputs():
    """An empty batch yields an empty zero-length partition."""
    _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])

    assert filtered_groups == []
    assert indices == []


def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    """The non-zero partition is empty when every proposal length is 0."""
    proposal_lens = [0, 0, 0]

    (filtered_groups, indices), _ = split_batch_by_proposal_len(
        fake_sequence_group_metadata, proposal_lens)

    assert filtered_groups == []
    assert indices == []


def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    """The zero-length partition is empty when every proposal length is 1."""
    proposal_lens = [1, 1, 1]

    _, (filtered_groups, indices) = split_batch_by_proposal_len(
        fake_sequence_group_metadata, proposal_lens)

    assert filtered_groups == []
    assert indices == []


def mock_spec_decode_sampler(acceptance_sampler_method):
    """Create a MagicMock standing in for a speculative-decoding sampler.

    Args:
        acceptance_sampler_method: ``"rejection_sampler"`` for a mock
            spec'd against RejectionSampler, or
            ``"typical_acceptance_sampler"`` for a mock spec'd against
            TypicalAcceptanceSampler.

    Returns:
        A ``MagicMock(spec=...)`` of the selected sampler class whose
        ``token_id_dtype`` attribute is set to ``torch.int64``.

    Raises:
        ValueError: If ``acceptance_sampler_method`` names neither sampler.
    """
    known_methods = ("rejection_sampler", "typical_acceptance_sampler")
    if acceptance_sampler_method not in known_methods:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")

    if acceptance_sampler_method == "rejection_sampler":
        spec_cls = RejectionSampler
    else:
        spec_cls = TypicalAcceptanceSampler

    mock_sampler = MagicMock(spec=spec_cls)
    mock_sampler.token_id_dtype = torch.int64
    return mock_sampler