test_ngram.py 8.42 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import numpy as np
4

5
6
7
8
9
from vllm.config import (
    ModelConfig,
    SpeculativeConfig,
    VllmConfig,
)
10
from vllm.v1.spec_decode.ngram_proposer import (
11
12
13
    NgramProposer,
    _find_longest_matched_ngram_and_propose_tokens,
)
14
15


16
17
def test_find_longest_matched_ngram_and_propose_tokens():
    """Check the n-gram lookup helper directly on small token arrays.

    Every call fixes ``min_ngram == max_ngram`` so a single n-gram length is
    exercised at a time, and varies ``k`` to check how many tokens come back.
    """
    # The trailing pair [5, 6] appears only at the very end of the sequence;
    # the expected proposal is empty.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    result = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
    )
    assert len(result) == 0

    # The trailing [2, 3] (and the trailing [3]) also occur earlier, where
    # they are followed by 4, 1, 2. The expected proposal is the first k of
    # those tokens.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
        ),
        np.array([4, 1]),
    )
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
        ),
        np.array([4, 1]),
    )

    # Token 3 occurs at indices 1 and 4 before the final position, while the
    # pair [2, 3] occurs only once earlier (indices 3-4).
    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
        ),
        np.array([4, 1, 2]),
    )
    # Return on the first match: the 1-gram proposal follows the occurrence
    # of 3 at index 1, not the later one at index 4.
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
        ),
        np.array([6, 2]),
    )
63
64
65


def test_ngram_proposer():
    """Check ``NgramProposer.propose`` on single, batched and sparse inputs.

    Scenarios, in order: no match, choosing between n-gram lengths inside the
    ``prompt_lookup_min``/``prompt_lookup_max`` window, several candidate
    matches, empty input, a padded two-request batch, a batch in which one
    request has no sampled token, and a run with
    ``num_numba_thread_available`` forced to 0.
    """

    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        """Build a proposer for n-gram sizes ``min_n..max_n`` and ``k`` drafts."""
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        return NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=min_n,
                    prompt_lookup_max=max_n,
                    num_speculative_tokens=k,
                    method="ngram",
                ),
            )
        )

    # No match.
    token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # No match for 4-gram but match for 3-gram.
    token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[4, 1]]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
    result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [5, 2]

    # Multiple 3-gram matches, but always pick the first one.
    token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
    result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert np.array_equal(result, np.array([[100, 1]]))

    # Check empty input.
    token_ids_cpu = np.array([[]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0]],
        req_ids=["0"],
        num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 0

    # Check multibatch input.
    # First request has 5 tokens and a match.
    # Second request has 3 tokens and no match. Padded with -1 for max len 5.
    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([5, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([3, 1]))
    assert np.array_equal(result[1], np.array([]))

    # Test non-contiguous indices: requests 0 and 2 need proposals,
    # request 1 is in prefill.
    proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    max_model_len = 20
    token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32)
    token_ids_cpu[0, :5] = [1, 2, 3, 1, 2]
    token_ids_cpu[1, :3] = [4, 5, 6]
    token_ids_cpu[2, :5] = [7, 8, 9, 7, 8]
    num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32)
    sampled_token_ids = [[2], [], [8]]  # Empty list for request 1 simulates prefill
    result = proposer.propose(
        sampled_token_ids=sampled_token_ids,
        req_ids=["0", "1", "2"],
        num_tokens_no_spec=num_tokens_no_spec,
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result) == 3
    assert np.array_equal(result[0], [3, 1])
    assert len(result[1]) == 0
    assert np.array_equal(result[2], [9, 7])
    # Verify internal arrays written to correct indices.
    assert proposer.valid_ngram_num_drafts[0] == 2
    assert proposer.valid_ngram_num_drafts[1] == 0
    assert proposer.valid_ngram_num_drafts[2] == 2
    assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1])
    assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7])

    # Test with 0 threads available: can happen if TP size > CPU count.
    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    ngram_proposer.num_numba_thread_available = 0
    # Set max_model_len to 2 * threshold to ensure multithread is used.
    num_tokens_threshold = ngram_proposer.num_tokens_threshold
    ngram_proposer.max_model_len = 2 * num_tokens_threshold
    # Using the multibatch layout: request 0 is 0..threshold-1 followed by a
    # repeat of the two middle values; request 1 is 3 real tokens padded
    # with -1 to the same length.
    middle_integer = num_tokens_threshold // 2
    input_1 = list(range(num_tokens_threshold))
    input_1 += [middle_integer, middle_integer + 1]
    input_2 = [-1] * len(input_1)
    input_2[:3] = [4, 5, 6]
    token_ids_cpu = np.array([input_1, input_2])
    result = ngram_proposer.propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([len(input_1), 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
    assert np.array_equal(result[1], np.array([]))