test_ngram.py 4.17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

zhuwenwen's avatar
zhuwenwen committed
4
import os
5
import numpy as np
6

7
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
8
9
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                _find_subarray_kmp,
10
                                                _kmp_lps_array)
zhuwenwen's avatar
zhuwenwen committed
11
from ...utils import models_path_prefix
12
13


14
15
16
17
18
19
20
21
22
def test_kmp_lps_array():
    """Check the KMP prefix table built by ``_kmp_lps_array``.

    Each pair below maps an input pattern to the expected table, where
    entry ``i`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.
    """
    cases = [
        ([], []),
        ([1], [0]),
        ([1, 1, 1], [0, 1, 2]),
        ([1, 2, 3, 4], [0, 0, 0, 0]),
        ([1, 2, 1, 2, 3], [0, 0, 1, 2, 0]),
    ]
    for pattern, expected in cases:
        np.testing.assert_array_equal(_kmp_lps_array(np.array(pattern)),
                                      np.array(expected))


def test_find_subarray_kmp():
    """Check ``_find_subarray_kmp(tokens, n, k)``.

    The cases below expect the helper to take the trailing ``n`` tokens,
    locate an earlier occurrence of that n-gram, and return the ``k``
    tokens that follow it, or ``None`` when the n-gram does not recur.
    """
    # The trailing 2-gram [5, 6] never appears earlier: nothing to propose.
    no_repeat = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
    assert _find_subarray_kmp(no_repeat, 2, 2) is None

    # Both the trailing 2-gram [2, 3] and the trailing 1-gram [3] recur
    # at the start of the array, so the tokens after them are proposed.
    repeat = np.array([1, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(_find_subarray_kmp(repeat, 2, 3),
                                  np.array([4, 1, 2]))
    np.testing.assert_array_equal(_find_subarray_kmp(repeat, 2, 2),
                                  np.array([4, 1]))
    np.testing.assert_array_equal(_find_subarray_kmp(repeat, 1, 3),
                                  np.array([4, 1, 2]))
    np.testing.assert_array_equal(_find_subarray_kmp(repeat, 1, 2),
                                  np.array([4, 1]))

    multi = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
    np.testing.assert_array_equal(_find_subarray_kmp(multi, 2, 3),
                                  np.array([4, 1, 2]))
    # Return on the first match: [3] occurs earlier at indices 1 and 4,
    # and the earliest occurrence (index 1) is expected to win.
    np.testing.assert_array_equal(_find_subarray_kmp(multi, 1, 3),
                                  np.array([6, 2, 3]))


def test_ngram_proposer():
    """Check ``NgramProposer.propose`` end to end on short token contexts.

    Each case builds a proposer whose n-gram window is
    [prompt_lookup_min, prompt_lookup_max] with ``k`` speculative tokens.
    It then checks either the draft tokens proposed, or that ``None`` is
    returned when there is no match. When several n-gram sizes match,
    the cases expect the longest one to win.
    """

    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
        # Dummy model config. Just to set max_model_len.
        # The same checkpoint path serves as both model and tokenizer.
        model_path = os.path.join(models_path_prefix, "facebook/opt-125m")
        model_config = ModelConfig(model=model_path,
                                   task="generate",
                                   max_model_len=100,
                                   tokenizer=model_path,
                                   tokenizer_mode="auto",
                                   dtype="auto",
                                   seed=None,
                                   trust_remote_code=False)
        speculative_config = SpeculativeConfig.from_dict({
            "prompt_lookup_min": min_n,
            "prompt_lookup_max": max_n,
            "num_speculative_tokens": k,
            "method": "ngram",
        })
        return NgramProposer(
            vllm_config=VllmConfig(model_config=model_config,
                                   speculative_config=speculative_config))

    # No match.
    result = ngram_proposer(
        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
    assert result is None

    # No match for 4-gram.
    result = ngram_proposer(
        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert result is None

    # No match for 4-gram but match for 3-gram.
    result = ngram_proposer(
        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
    assert np.array_equal(result, np.array([4, 1]))

    # Match for both 4-gram and 3-gram.
    # In this case, the proposer should return the 4-gram match.
    result = ngram_proposer(3, 4, 2).propose(
        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]

    # Match for 2-gram and 3-gram, but not 4-gram.
    # The longer (3-gram) match is expected to take precedence.
    result = ngram_proposer(2, 4, 2).propose(
        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]