Unverified commit cde8d247, authored by Matthew Bonanni, committed by GitHub
Browse files

[Spec Decode] Move `SpecDecodeBaseProposer` out of `eagle.py` (#40732)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 4a6dd1c3
...@@ -741,9 +741,9 @@ def test_set_inputs_first_pass_parallel_drafting(): ...@@ -741,9 +741,9 @@ def test_set_inputs_first_pass_parallel_drafting():
@pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@pytest.mark.parametrize("use_distinct_lm_head", [True, False]) @pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_model")
def test_load_model( def test_load_model(
mock_get_model, mock_get_model,
mock_get_layers, mock_get_layers,
......
...@@ -61,9 +61,9 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: ...@@ -61,9 +61,9 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
return EagleProposer(vllm_config=vllm_config, device=DEVICE_TYPE) return EagleProposer(vllm_config=vllm_config, device=DEVICE_TYPE)
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model") @mock.patch("vllm.v1.spec_decode.llm_base_proposer.get_model")
def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group): def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach.""" """Test MTP-specific model loading with unified model approach."""
......
...@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context ...@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -9,7 +9,7 @@ from vllm.config import VllmConfig ...@@ -9,7 +9,7 @@ from vllm.config import VllmConfig
from vllm.config.utils import replace from vllm.config.utils import replace
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer
logger = init_logger(__name__) logger = init_logger(__name__)
......
This diff is collapsed.
This diff is collapsed.
...@@ -65,16 +65,16 @@ class CPUModelRunner(GPUModelRunner): ...@@ -65,16 +65,16 @@ class CPUModelRunner(GPUModelRunner):
# Speculative decoding fallbacks # Speculative decoding fallbacks
import vllm.v1.sample.rejection_sampler import vllm.v1.sample.rejection_sampler
import vllm.v1.spec_decode.eagle import vllm.v1.spec_decode.llm_base_proposer
import vllm.v1.spec_decode.utils import vllm.v1.spec_decode.utils
vllm.v1.spec_decode.eagle.eagle_prepare_inputs_padded_kernel = ( vllm.v1.spec_decode.llm_base_proposer.eagle_prepare_inputs_padded_kernel = (
cpu_tl.eagle_prepare_inputs_padded_kernel cpu_tl.eagle_prepare_inputs_padded_kernel
) )
vllm.v1.spec_decode.eagle.eagle_prepare_next_token_padded_kernel = ( vllm.v1.spec_decode.llm_base_proposer.eagle_prepare_next_token_padded_kernel = (
cpu_tl.eagle_prepare_next_token_padded_kernel cpu_tl.eagle_prepare_next_token_padded_kernel
) )
vllm.v1.spec_decode.eagle.copy_and_expand_eagle_inputs_kernel = ( vllm.v1.spec_decode.llm_base_proposer.copy_and_expand_eagle_inputs_kernel = (
cpu_tl.copy_and_expand_eagle_inputs_kernel cpu_tl.copy_and_expand_eagle_inputs_kernel
) )
vllm.v1.spec_decode.utils.eagle_step_slot_mapping_metadata_kernel = ( vllm.v1.spec_decode.utils.eagle_step_slot_mapping_metadata_kernel = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment