speculative.py 27.9 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import hashlib
6
from typing import TYPE_CHECKING, Any, Literal
7

8
from pydantic import Field, SkipValidation, model_validator
9
10
11
12
13
14
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
15
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
16
17
18
19
20
21
22
23
24
25

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
    from vllm.config import ModelConfig
else:
    PretrainedConfig = Any
    ModelConfig = Any

26
27
28
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )
29
30
31

logger = init_logger(__name__)

32
33
34
35
36
37
38
39
40
41
42
43
44
# Every value accepted for `SpeculativeConfig.method`. The per-model `*_mtp`
# names are deprecated aliases that are collapsed into "mtp"; each entry of
# MTP_MODEL_TYPES must therefore also be listed here.
SpeculativeMethod = Literal[
    "ngram",
    "eagle",
    "eagle3",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "deepseek_mtp",
    "ernie_mtp",
    "qwen3_next_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "longcat_flash_mtp",
    "mtp",
    "suffix",
]
# HF `model_type` values (after `hf_config_override`) that denote an MTP draft
# model. Also used to recognise deprecated per-model method names.
MTP_MODEL_TYPES = (
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "ernie_mtp",
    "qwen3_next_mtp",
    "longcat_flash_mtp",
)
55
56
57
58
59
60


@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""
    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation."""
    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized
    internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
170
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
171
172
173
174
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
175
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
176
177
178
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
179
180
181
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )
182
183
184
185

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
186
187
188
189
190
191
192
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )
193
194
195
196

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
197
198
199
200
201
202
203
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )
204
205
206
207
208

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
209
210
211
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )
212
213
214
215
216

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
217
218
219
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )
XuruiYang's avatar
XuruiYang committed
220
221
222
        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
223
224
225
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )
226
227
228
229
230
231
232
233
234
235
236
237

        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the draft-side configs.

        Steps, in order:

        * deprecated per-model MTP method names are collapsed into `mtp`;
        * when `model` is unset but `num_speculative_tokens` is given, the
          model is derived from the method (`mtp` reuses the target model,
          `ngram` / `suffix` use their own name);
        * for `ngram`, the lookup window is defaulted and checked and the
          target configs are reused as draft configs;
        * for `suffix`, the suffix decoding options are validated;
        * otherwise, when a model is set, a draft `ModelConfig` is built, the
          method is auto-detected where possible, and the token tree, draft
          tensor-parallel size, draft max model length and draft parallel
          config are derived.

        Raises:
            ValueError: If `num_speculative_tokens` is given without a model
                or a method that implies one, if `target_model_config` is
                missing for `mtp`, if `prompt_lookup_min` exceeds
                `prompt_lookup_max`, or if `num_speculative_tokens` exceeds
                and is not a multiple of the draft model's `n_predict`. The
                helpers called here may raise further `ValueError`s.
            NotImplementedError: If the draft model matches none of the
                supported speculative methods.
        """
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        if self.method in MTP_MODEL_TYPES:
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                target_model_type = self.target_model_config.hf_text_config.model_type
                if target_model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without "
                    "speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method"
        if self.method is None and self.model in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                # Only the max was given: use a fixed-size window.
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                # Only the min was given: use a fixed-size window.
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                # TODO: Move this import to the top once `ModelConfig`
                # lives in `vllm.config.model`.
                from vllm.config import ModelConfig

                target = self.target_model_config
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=target.tokenizer,
                    tokenizer_mode=target.tokenizer_mode,
                    trust_remote_code=target.trust_remote_code,
                    allowed_local_media_path=target.allowed_local_media_path,
                    allowed_media_domains=target.allowed_media_domains,
                    dtype=target.dtype,
                    seed=target.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=target.tokenizer_revision,
                    spec_target_max_model_len=target.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=target.enforce_eager,
                    max_logprobs=target.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
                    # Covers every MTP model type, including
                    # "longcat_flash_mtp".
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further down.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer, "
                            "which may result in lower acceptance rate"
                        )
                else:
                    self.method = "draft_model"
                    raise NotImplementedError(
                        "Speculative decoding with draft model is not "
                        "supported yet. Please consider using other "
                        "speculative decoding methods such as ngram, medusa, "
                        "eagle, or mtp."
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    # Configs that are already EAGLE/Speculators configs are
                    # used as-is; anything else is wrapped.
                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        self.draft_model_config.hf_config = eagle_config

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    # Generate chain of tokens. Skipped while
                    # num_speculative_tokens is still unknown so that
                    # `_verify_args` can report the missing value instead of
                    # `range(None)` raising a TypeError here.
                    if self.num_speculative_tokens is not None:
                        self.speculative_token_tree = str(
                            [
                                (i + 1) * (0,)
                                for i in range(self.num_speculative_tokens)
                            ]
                        )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self
458

459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
    def _validate_suffix_decoding(self):
        """Check the suffix decoding prerequisites and option ranges.

        When `num_speculative_tokens` is unset it is defaulted to
        `suffix_decoding_max_tree_depth` and a warning is logged.

        Raises:
            ImportError: If `has_arctic_inference()` reports that Arctic
                Inference is unavailable.
            ValueError: If any `suffix_decoding_*` option is out of range.
        """
        arctic_available = has_arctic_inference()
        if not arctic_available:
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.0`."
            )

        # Suffix decoding decides the actual number of speculative tokens
        # dynamically and treats num_speculative_tokens as a maximum limit.
        if self.num_speculative_tokens is None:
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

        # Range checks, one option at a time.
        tree_depth = self.suffix_decoding_max_tree_depth
        if tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth={tree_depth} must be >= 1"
            )

        cached_requests = self.suffix_decoding_max_cached_requests
        if cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests={cached_requests}"
                " must be >= 0"
            )

        spec_factor = self.suffix_decoding_max_spec_factor
        if spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor={spec_factor} must be >= 0"
            )

        min_prob = self.suffix_decoding_min_token_prob
        # Kept as a negated chained comparison so NaN is rejected as well.
        if not (0 <= min_prob <= 1):
            raise ValueError(
                f"suffix_decoding_min_token_prob={min_prob} must be in [0, 1]"
            )

495
496
    @staticmethod
    def _maybe_override_draft_max_model_len(
497
        speculative_max_model_len: int | None,
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
515
516
517
518
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
519
520

            if speculative_max_model_len > target_max_model_len:
521
522
523
524
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
525
526
527
528
529
530
531
532
533
534

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
535
        target_parallel_config: ParallelConfig,
536
        speculative_draft_tensor_parallel_size: int | None,
537
538
        draft_hf_config: PretrainedConfig,
    ) -> int:
539
540
541
542
543
544
545
546
547
548
549
550
551
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.
        """
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
552
553
                        draft_hf_config.model_type,
                    )
554
            else:
555
                speculative_draft_tensor_parallel_size = (
556
                    target_parallel_config.tensor_parallel_size
557
                )
558
        elif speculative_draft_tensor_parallel_size not in (
559
560
561
            1,
            target_parallel_config.tensor_parallel_size,
        ):
562
563
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
564
565
                f"other value than 1 or target model tensor_parallel_size"
            )
566
567
568
569
570
571
572
573
574
575
576
577
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.

        Args:
            target_parallel_config: Parallel config of the target model.
            speculative_draft_tensor_parallel_size: TP size for the draft
                model.

        Returns:
            A new `ParallelConfig` carrying the draft TP size; the other
            arguments passed here are copied from the target config.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
            ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

589
    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate the resolved speculative decoding configuration.

        Returns:
            This config instance.

        Raises:
            ValueError: If `num_speculative_tokens` is missing or not
                positive, if `disable_by_batch_size` is below 2, or if
                `eagle3` is used with an unsupported target model type. The
                draft model's `verify_with_parallel_config` is also called
                and may raise on its own.
        """
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )

        if self.num_speculative_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        # Matched as substrings of the target's model_type (e.g. "qwen"
        # matches any model_type containing "qwen").
        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )

        return self

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def use_eagle(self) -> bool:
643
        return self.method in ("eagle", "eagle3", "mtp")
644
645
646

    def __repr__(self) -> str:
        method = self.method
647
        model = None if method in ("ngram", "suffix") else self.draft_model_config.model
648
649
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"