"vscode:/vscode.git/clone" did not exist on "b31a1fb63c98fa1c64666aaae15579439af60d95"
Unverified commit c438b295, authored by Rahul Tuli, committed by GitHub
Browse files

feat: Enable engine-level arguments with speculators models (#25250)


Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent 0ff8ebb2
......@@ -3,38 +3,52 @@
import pytest
import torch
from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
@pytest.mark.parametrize("model_path", [
pytest.param(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
id="llama3-eagle3-speculator"),
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
id="qwen3-eagle3-speculator"),
])
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
monkeypatch):
"""
Test Eagle3 speculators models properly initialize speculative decoding.
This test verifies:
1. Eagle3 support is detected for the model
2. Speculative config is automatically initialized from embedded config
3. The draft model path is correctly set to the speculators model
4. Speculative tokens count is valid
5. Text generation works with speculative decoding enabled
"""
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
# Verify Eagle3 support is detected
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
vllm_config = vllm_model.llm.llm_engine.vllm_config
assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \
"Speculative config should be initialized for speculators model"
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
spec_config = vllm_config.speculative_config
assert spec_config.num_speculative_tokens > 0, \
(f"Expected positive speculative tokens, "
f"got {spec_config.num_speculative_tokens}")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
assert spec_config.model == model_path, \
f"Draft model should be {model_path}, got {spec_config.model}"
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
assert vllm_outputs, \
f"No outputs generated for speculators model {model_path}"
......@@ -27,8 +27,7 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
is_interleaved, maybe_override_with_speculators_target_model,
try_get_generation_config, try_get_safetensors_metadata,
is_interleaved, try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
is_runai_obj_uri)
......@@ -416,15 +415,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if self.runner != "draft":
# If we're not running the draft model, check for speculators config
# If speculators config, set model / tokenizer to be target model
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code)
if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
raise ValueError(
......
......@@ -41,7 +41,8 @@ from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import get_model_path, is_interleaved
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
maybe_override_with_speculators)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
......@@ -1082,29 +1083,8 @@ class EngineArgs:
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
"""
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)
if self.speculative_config is None:
hf_config = get_config(
self.hf_config_path or target_model_config.model,
self.trust_remote_code, self.revision, self.code_revision,
self.config_format)
# if loading a SpeculatorsConfig, load the speculative_config
# details from the config directly
# no user input required / expected
if isinstance(hf_config, SpeculatorsConfig):
# We create one since we don't create one
self.speculative_config = {}
self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
self.speculative_config["model"] = target_model_config.model
self.speculative_config["method"] = hf_config.method
else:
return None
return None
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
......@@ -1139,6 +1119,15 @@ class EngineArgs:
device_config = DeviceConfig(
device=cast(Device, current_platform.device_type))
(self.model, self.tokenizer,
self.speculative_config) = maybe_override_with_speculators(
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
vllm_speculative_config=self.speculative_config,
)
model_config = self.create_model_config()
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
......
......@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config
def maybe_override_with_speculators_target_model(
def maybe_override_with_speculators(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None,
vllm_speculative_config: Optional[dict[str, Any]] = None,
**kwargs,
) -> tuple[str, str]:
) -> tuple[str, str, Optional[dict[str, Any]]]:
"""
If running a speculators config, override running model with target model
Resolve model configuration when speculators are detected.
Checks if the provided model is a speculators model and if so, extracts
the target model configuration and builds the speculative config.
Args:
model: Model name or path
tokenizer: Tokenizer name or path
trust_remote_code: Whether to trust remote code
revision: Model revision
vllm_speculative_config: Existing vLLM speculative config
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
is_gguf = check_gguf_file(model)
if is_gguf:
......@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
token=_get_hf_token(),
**kwargs,
)
spec_config = config_dict.get("speculators_config", None)
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer
speculators_config = config_dict.get("speculators_config")
if speculators_config is None:
# No speculators config found, return original values
return model, tokenizer, vllm_speculative_config
# Speculators format detected - process overrides
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
config_dict=config_dict)
# Set the draft model to the speculators model
vllm_speculative_config["model"] = model
# Override model and tokenizer with the verifier model from config
verifier_model = speculators_config["verifier"]["name_or_path"]
model = tokenizer = verifier_model
return model, tokenizer, vllm_speculative_config
def get_config(
......
......@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)
vllm_config = cls.extract_vllm_speculative_config(config_dict)
return cls(**vllm_config)
@classmethod
def extract_vllm_speculative_config(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
......@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
vllm_config = cls.build_vllm_speculative_config(
config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return cls(**vllm_config)
return vllm_config
@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
......@@ -60,32 +67,45 @@ class SpeculatorsConfig(PretrainedConfig):
"'transformer_layer_config' must be a dictionary if provided")
@classmethod
def convert_speculators_to_vllm(
def build_vllm_speculative_config(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
"""
Convert speculators config format to vLLM format.
This method handles the translation of field names and structure
between speculators and vLLM formats.
Build vLLM-compatible speculative configuration from speculators format.
This method extracts and transforms speculative configuration from the
speculators format into the structure expected by vLLM.
Args:
config_dict: Configuration dictionary in speculators format
Returns:
Dictionary with vLLM-compatible configuration
Dictionary with vLLM-compatible speculative configuration
"""
# Currently we only support one proposal method
# Extract speculators configuration
spec_config = config_dict["speculators_config"]
first_method = spec_config.get("proposal_methods")[0]
num_lookahead_tokens = first_method.get("speculative_tokens")
if num_lookahead_tokens is None:
# Currently we only support one proposal method
proposal_methods = spec_config.get("proposal_methods")
if not proposal_methods:
raise ValueError("No proposal methods found in speculators config")
first_method = proposal_methods[0]
num_speculative_tokens = first_method.get("speculative_tokens")
if num_speculative_tokens is None:
raise ValueError(
"Missing 'speculative_tokens' in proposal method. "
f"Got: {first_method}")
# Build base vLLM config
# Build base vLLM speculative configuration
vllm_config = {
"method": config_dict.get("speculators_model_type"),
"num_lookahead_tokens": num_lookahead_tokens,
"num_speculative_tokens": num_speculative_tokens,
"target_model": spec_config.get("verifier")["name_or_path"]
}
vllm_config.update(config_dict["transformer_layer_config"])
# Merge transformer layer configuration if present
transformer_config = config_dict.get("transformer_layer_config", {})
vllm_config.update(transformer_config)
return vllm_config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment