test_ngram.py 8.41 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import numpy as np
4

5
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
6
7
from vllm.v1.spec_decode.ngram_proposer import (
    NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
8
9


10
11
def test_find_longest_matched_ngram_and_propose_tokens():
    """Check the low-level n-gram lookup helper on small token arrays.

    Every call pins ``min_ngram == max_ngram`` so that exactly one n-gram
    size is exercised at a time. The scenarios cover: no earlier occurrence
    of the trailing n-gram, a single earlier occurrence with different
    proposal lengths ``k``, and several earlier occurrences (the first one
    is expected to win).
    """
    # No match: the trailing 2-gram [5, 6] does not appear earlier.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    result = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=tokens,
        min_ngram=2,
        max_ngram=2,
        max_model_len=1024,
        k=2)
    assert len(result) == 0

    # Single earlier match: the trailing 2-gram [2, 3] (and the trailing
    # 1-gram [3]) is followed by [4, 1, 2] in the prefix; k truncates the
    # proposal.
    tokens = np.array([1, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=2,
                                                       max_ngram=2,
                                                       max_model_len=1024,
                                                       k=3),
        np.array([4, 1, 2]))
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=2,
                                                       max_ngram=2,
                                                       max_model_len=1024,
                                                       k=2), np.array([4, 1]))
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=1,
                                                       max_ngram=1,
                                                       max_model_len=1024,
                                                       k=3),
        np.array([4, 1, 2]))
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=1,
                                                       max_ngram=1,
                                                       max_model_len=1024,
                                                       k=2), np.array([4, 1]))

    # The 2-gram [2, 3] occurs once in the prefix, but the 1-gram [3]
    # occurs twice (after 1 and after 2).
    tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=2,
                                                       max_ngram=2,
                                                       max_model_len=1024,
                                                       k=3),
        np.array([4, 1, 2]))
    # Return on the first match
    np.testing.assert_array_equal(
        _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
                                                       min_ngram=1,
                                                       max_ngram=1,
                                                       max_model_len=1024,
                                                       k=2), np.array([6, 2]))
63
64
65


def test_ngram_proposer():
    """Exercise ``NgramProposer.propose`` end to end.

    Covers single-request batches (no match, fallback from a longer to a
    shorter n-gram, preference for the longest n-gram, first-occurrence
    tie-breaking, empty input), a padded two-request batch, and the case
    where the proposer reports zero available numba threads.
    """

    def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        model_config = ModelConfig(model="facebook/opt-125m")
        return NgramProposer(
            vllm_config=VllmConfig(model_config=model_config,
                                   speculative_config=SpeculativeConfig(
                                       prompt_lookup_min=min_n,
                                       prompt_lookup_max=max_n,
                                       num_speculative_tokens=k,
                                       method="ngram",
                                   )))

    def propose_single(token_ids_cpu: np.ndarray, min_n: int, max_n: int,
                       k: int):
        # Run a fresh proposer on a batch whose rows are all unpadded, so
        # each row's length is its number of real tokens.
        return get_ngram_proposer(min_n=min_n, max_n=max_n, k=k).propose(
            sampled_token_ids=[[0]],
            req_ids=["0"],
            num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
            token_ids_cpu=token_ids_cpu,
            spec_decode_unsupported_reqs=(),
        )

    # No match.
    result = propose_single(np.array([[1, 2, 3, 4, 5]]),
                            min_n=2,
                            max_n=2,
                            k=2)
    assert len(result[0]) == 0

    # No match for 4-gram.
    result = propose_single(np.array([[1, 2, 3, 4, 1, 2, 3]]),
                            min_n=4,
                            max_n=4,
                            k=2)
    assert len(result[0]) == 0

    # No match for 4-gram but match for 3-gram.
    result = propose_single(np.array([[1, 2, 3, 4, 1, 2, 3]]),
                            min_n=3,
                            max_n=4,
                            k=2)
    assert np.array_equal(result, np.array([[4, 1]]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = propose_single(np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]),
                            min_n=3,
                            max_n=4,
                            k=2)
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 1]]

    # Match for 2-gram and 3-gram, but not 4-gram.
    result = propose_single(np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]),
                            min_n=2,
                            max_n=4,
                            k=2)
    assert np.array_equal(result, np.array([[1, 2]]))  # Not [[5, 2]]

    # Multiple 3-gram matched, but always pick the first one.
    result = propose_single(np.array(
        [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]),
                            min_n=3,
                            max_n=3,
                            k=2)
    assert np.array_equal(result, np.array([[100, 1]]))

    # check empty input
    result = propose_single(np.array([[]]), min_n=2, max_n=2, k=2)
    assert len(result[0]) == 0

    # check multibatch input
    # first request has 5 tokens and a match
    # second request has 3 tokens and no match. Padded with -1 for max len 5
    token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
    result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([5, 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0], np.array([3, 1]))
    assert np.array_equal(result[1], np.array([]))

    # test if 0 threads available: can happen if TP size > CPU count
    ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2)
    ngram_proposer.num_numba_thread_available = 0
    # set max_model_len to 2 * threshold to ensure multithread is used
    num_tokens_threshold = ngram_proposer.num_tokens_threshold
    ngram_proposer.max_model_len = 2 * num_tokens_threshold
    # using multibatch test
    # Request 0 is 0..threshold-1 followed by a repeat of the 2-gram
    # [middle, middle + 1], so the proposal should continue from
    # middle + 2. Request 1 has only 3 real tokens; the rest is -1 padding.
    middle_integer = num_tokens_threshold // 2
    input_1 = list(range(num_tokens_threshold))
    input_1 += [middle_integer, middle_integer + 1]
    input_2 = [-1] * len(input_1)
    input_2[:3] = [4, 5, 6]
    token_ids_cpu = np.array([input_1, input_2])
    result = ngram_proposer.propose(
        sampled_token_ids=[[0], [1]],
        req_ids=["0", "1"],
        num_tokens_no_spec=np.array([len(input_1), 3]),
        token_ids_cpu=token_ids_cpu,
        spec_decode_unsupported_reqs=(),
    )
    assert len(result[0]) == 2
    assert np.array_equal(result[0],
                          np.array([middle_integer + 2, middle_integer + 3]))
    assert np.array_equal(result[1], np.array([]))