Unverified Commit bb3605db authored by qizixi's avatar qizixi Committed by GitHub
Browse files

[Bugfix] Fix v1/spec_decode/test_ngram.py (#16895)


Signed-off-by: default avatarqizixi <qizixi@meta.com>
parent fe742aef
......@@ -2,6 +2,7 @@
import numpy as np
from vllm.config import SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array)
......@@ -39,50 +40,40 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
proposer = NgramProposer()
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
return NgramProposer(vllm_config=VllmConfig(
speculative_config=SpeculativeConfig.from_dict(
{
"prompt_lookup_min": min_n,
"prompt_lookup_max": max_n,
"num_speculative_tokens": k,
"method": "ngram",
})))
# No match.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 5]),
min_n=2,
max_n=2,
k=2,
)
result = ngram_proposer(
2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
assert result is None
# No match for 4-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=4,
max_n=4,
k=2,
)
result = ngram_proposer(
4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert result is None
# No match for 4-gram but match for 3-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=3,
max_n=4,
k=2,
)
result = ngram_proposer(
3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = proposer.propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
min_n=3,
max_n=4,
k=2,
)
result = ngram_proposer(3, 4, 2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
result = proposer.propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
min_n=2,
max_n=4,
k=2,
)
result = ngram_proposer(
2, 4,
2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
......@@ -120,7 +120,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
......@@ -266,7 +266,7 @@ class ModelConfig:
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
......@@ -1624,7 +1624,7 @@ class ParallelConfig:
"""The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform."""
sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decofing.
"""The full name of the worker class to use for speculative decofing.
If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension
......@@ -1815,13 +1815,13 @@ class SchedulerConfig:
max_num_batched_tokens: int = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
max_num_seqs: int = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
......@@ -1867,7 +1867,7 @@ class SchedulerConfig:
# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False)
"""Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
......@@ -2306,7 +2306,8 @@ class SpeculativeConfig:
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \
if self.target_model_config and \
self.target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
self.model = self.target_model_config.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment