Unverified Commit a2b877df authored by 22quinn's avatar 22quinn Committed by GitHub
Browse files

[Bugfix] Lazy import NgramProposer in GPU model runner (#32821)


Signed-off-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 35fb0b86
...@@ -150,7 +150,6 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer ...@@ -150,7 +150,6 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
...@@ -185,6 +184,7 @@ from .utils import ( ...@@ -185,6 +184,7 @@ from .utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -439,13 +439,15 @@ class GPUModelRunner( ...@@ -439,13 +439,15 @@ class GPUModelRunner(
# layers in the draft model. # layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: ( self.drafter: (
NgramProposer NgramProposer # noqa: F823
| SuffixDecodingProposer | SuffixDecodingProposer
| EagleProposer | EagleProposer
| DraftModelProposer | DraftModelProposer
| MedusaProposer | MedusaProposer
) )
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.uses_draft_model(): elif self.speculative_config.uses_draft_model():
self.drafter = DraftModelProposer( self.drafter = DraftModelProposer(
...@@ -3848,6 +3850,8 @@ class GPUModelRunner( ...@@ -3848,6 +3850,8 @@ class GPUModelRunner(
spec_config = self.speculative_config spec_config = self.speculative_config
assert spec_config is not None assert spec_config is not None
if spec_config.method == "ngram": if spec_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment