# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import copy
from typing import TYPE_CHECKING, Any, Literal, get_args

from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self

from vllm.config import LoadConfig
from vllm.config.kernel import MoEBackend
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    # Runtime fallbacks: PretrainedConfig is only needed for annotations, and
    # the quantization package is wrapped in a LazyLoader (presumably so it is
    # not imported until first use).
    PretrainedConfig = Any

    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Model types of multi-token-prediction (MTP) drafters. Every name except the
# generic "mtp" is accepted as a deprecated `method` alias and normalized to
# "mtp" by SpeculativeConfig.__post_init__.
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "nemotron_h_mtp",
    "exaone_moe_mtp",
    "exaone4_5_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
    "hy_v3_mtp",
]
NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"]
# Nested Literal aliases are flattened, so get_args() on these yields the
# plain method-name strings.
EagleModelTypes = Literal[
    "eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes
]
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
    NgramGPUTypes,
]
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]


@config
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""

    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)  # type: ignore[assignment]
    """The number of speculative tokens, if provided. When unset, it defaults
    to the draft model config's `n_predict` if present (and, for suffix
    decoding, to `suffix_decoding_max_tree_depth`); otherwise, it is
    required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided. Model-specific MTP names (e.g. `deepseek_mtp`) are deprecated
    aliases of `mtp`.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    tensor_parallel_size: int | None = None
    """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
    warn users when they mistakenly provide the wrong argument."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | str | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method. For
    `mtp` without an explicit draft `model`, it defaults to the target model's
    quantization."""
    moe_backend: MoEBackend | None = None
    """MoE backend to use for the draft model. When `None`, the draft model
    inherits the target model's `--moe-backend` setting. Useful when the
    drafter and generator require different MoE kernels (e.g. quantized
    generator with unquantized drafter)."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""
    use_local_argmax_reduction: bool = False
    """Use vocab-parallel local argmax instead of all-gathering full logits
    for draft token generation. Reduces communication from O(vocab_size) to
    O(2 * tp_size) per token. Only applies to greedy draft selection in
    non-tree speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of the ngram token window for the `ngram` and `ngram_gpu`
    proposers. If neither window bound is given, both default to 5; if only
    `prompt_lookup_min` is given, this defaults to the same value."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of the ngram token window for the `ngram` and `ngram_gpu`
    proposers. If neither window bound is given, both default to 5; if only
    `prompt_lookup_max` is given, this defaults to the same value. Must not
    exceed `prompt_lookup_max`."""

    # Alternative drafting strategies
    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation, as a
    Python literal list of tuples, each tuple identifying a node by its path
    from the root (e.g. `[(0,), (0, 0), (0, 0, 0)]` is a 3-token chain).
    A provided tree is re-sorted breadth-first; if unset, a chain of
    `num_speculative_tokens` tokens is generated for model-based methods."""
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods. It is always enabled
    for the `dflash` method."""

    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    draft_load_config: LoadConfig | None = None
    """Load config for the draft model. If not specified, will use the load
    config from the target model."""

    rejection_sample_method: RejectionSampleMethod = "strict"
    """Rejection sampling method. `strict` accepts a draft token only when the
    target and draft sampled tokens match exactly; `probabilistic` performs
    probabilistic rejection sampling. Both respect the target model
    distribution, but the latter yields a higher acceptance rate at the cost
    of more memory to cache draft logits. `synthetic` instead accepts draft
    tokens with a position-dependent probability controlled by
    `synthetic_acceptance_rate`."""

    synthetic_acceptance_rate: float | None = None
    """Average acceptance rate for synthetic rejection sampling. Draft
    tokens are accepted with a position-dependent probability that decays
    geometrically, calibrated so that the mean rate across all speculative
    positions equals this value. Only used when rejection_sample_method
    is 'synthetic'. Must be in [0, 1]."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # eagle3, extract_hidden_states and dflash make the target model
        # return intermediate (auxiliary) hidden states in addition to the
        # final one, which changes the computation graph.
        aux_state_methods = ("eagle3", "extract_hidden_states", "dflash")
        needs_aux_states = self.method in aux_state_methods
        factors: list[Any] = [needs_aux_states]

        # Which layers are tapped for auxiliary states also changes the graph.
        draft_config = self.draft_model_config if needs_aux_states else None
        if draft_config is not None:
            aux_layer_ids = getattr(
                draft_config.hf_config, "eagle_aux_hidden_state_layer_ids", None
            )
            if aux_layer_ids is not None:
                # Normalize to a tuple so the str() fed to the hash is the
                # same whatever sequence type the config stored.
                factors.append(tuple(aux_layer_ids))

        return safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a target model's HF config into its MTP draft-model form.

        Passed as ``hf_overrides`` when the draft ``ModelConfig`` is built.
        For each supported base model, ``model_type`` is switched to the
        matching ``*_mtp`` type and ``architectures`` is pointed at the MTP
        draft module, with ``n_predict`` taken from the checkpoint's MTP layer
        count. The config is mutated in place; configs matching no rule are
        returned unchanged.

        Note: for ``NemotronH_Super_Omni_Reasoning_V3`` the nested
        ``text_config`` is processed and returned instead of the outer config.
        """
        initial_architecture = hf_config.architectures[0]
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Exact comparison: `x in ("pangu_ultra_moe")` would be a substring
        # test, since a parenthesized string is not a tuple.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

        if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
            hf_config.model_type = "glm_ocr_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["GlmOcrMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.architectures[0] == "NemotronH_Super_Omni_Reasoning_V3":
            # Promote VLM's text_config so MTP detection below fires correctly
            hf_config = hf_config.text_config

        if (
            hf_config.model_type in {"nemotron_h", "nemotron_h_puzzle"}
            and hasattr(hf_config, "num_nextn_predict_layers")
            and hf_config.num_nextn_predict_layers > 0
        ):
            # Check if this is an MTP variant
            hf_config.model_type = "nemotron_h_mtp"
        if hf_config.model_type == "nemotron_h_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        # Substring match on purpose: any model_type containing "exaone4_5".
        if "exaone4_5" in hf_config.model_type:
            hf_config.model_type = "exaone4_5_mtp"
        if hf_config.model_type == "exaone4_5_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Exaone4_5_MTP"]}
            )

        if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
            is_moe = hf_config.model_type == "qwen3_5_moe"
            hf_config.model_type = "qwen3_5_mtp"
            n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
                }
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

        # Decided on the architecture seen on entry, before any rewrite above.
        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        if hf_config.model_type == "hy_v3":
            hf_config.model_type = "hy_v3_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["HYV3MTPModel"]}
            )

        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the draft-side configs.

        Infers ``method`` when unset and normalizes deprecated MTP aliases to
        ``"mtp"``. It then fills defaults for the selected proposer: the ngram
        window, suffix-decoding limits, or hidden-state extraction. For
        model-based methods it builds ``draft_model_config`` and
        ``draft_parallel_config``.

        Raises:
            ValueError: on missing or inconsistent settings, e.g.
                ``num_speculative_tokens`` without a draft model, an inverted
                ngram window, or a token count not divisible by the draft's
                ``n_predict``.
            NotImplementedError: if the method cannot be handled for the
                given draft model.
        """
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        # infer method from user args
        if self.method is None:
            if self.model in ("ngram", "[ngram]"):
                self.method = "ngram"
            else:
                self.method = "draft_model"

        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "ngram_gpu":
                self.model = "ngram_gpu"
            elif self.method == "suffix":
                self.model = "suffix"
            elif self.method == "extract_hidden_states":
                self.model = "extract_hidden_states"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        if self.method in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "ngram_gpu"):
            # Fill in whichever window bound(s) the user left unset.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        elif self.method == "extract_hidden_states":
            from vllm.transformers_utils.configs.extract_hidden_states import (
                ExtractHiddenStatesConfig,
            )

            # ExtractHiddenStatesModel is instantiated manually in load_model()
            # We just need to store the target model config for KV cache shape info
            self.model = "extract_hidden_states"
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if hasattr(self.draft_model_config, "hf_config"):
                hf_config = self.draft_model_config.hf_config.to_dict()
            elif (
                isinstance(self.draft_model_config, dict)
                and "hf_config" in self.draft_model_config
            ):
                hf_config = self.draft_model_config["hf_config"]
            else:
                hf_config = {}

            self.draft_model_config = copy.copy(self.target_model_config)
            self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
                self.draft_model_config.hf_config, **hf_config
            )
            self.update_arch_()
            self.draft_parallel_config = self.target_parallel_config
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3", "dflash"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif "dflash" in self.draft_model_config.model.lower():
                    self.method = "dflash"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    # Covers every *_mtp model type, including longcat_flash_mtp.
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from the draft config's n_predict below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer"
                            ",which may result in lower acceptance rate"
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3", "dflash"):
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig
                    from vllm.transformers_utils.configs.speculators import (
                        SpeculatorsConfig,
                    )

                    # Configs that are already EAGLE/speculators-shaped are
                    # kept as-is; anything else is wrapped.
                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        self.draft_model_config.hf_config = eagle_config
                        self.update_arch_()

                if self.method == "dflash":
                    self.parallel_drafting = True

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    if self.num_speculative_tokens is None:
                        raise ValueError(
                            "A speculative model was provided, but neither "
                            "`speculative_token_tree` nor `num_speculative_tokens` "
                            "was provided"
                        )

                    # Generate chain of tokens.
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self

    def _validate_suffix_decoding(self):
        """Check the suffix-decoding settings and fill in defaults.

        All ``suffix_decoding_*`` values are range-checked *before*
        ``num_speculative_tokens`` is defaulted from
        ``suffix_decoding_max_tree_depth``. An invalid depth is therefore
        rejected without first being copied onto ``self`` and reported in the
        "Defaulted ..." warning.

        Raises:
            ImportError: if Arctic Inference is not installed.
            ValueError: if ``suffix_decoding_max_tree_depth`` < 1,
                ``suffix_decoding_max_cached_requests`` < 0,
                ``suffix_decoding_max_spec_factor`` < 0, or
                ``suffix_decoding_min_token_prob`` is outside [0, 1].
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )

        # Validate values first so nothing below acts on an invalid setting.
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

688
689
    @staticmethod
    def _maybe_override_draft_max_model_len(
690
        speculative_max_model_len: int | None,
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
708
709
710
711
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
712
713

            if speculative_max_model_len > target_max_model_len:
714
715
716
717
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
718
719
720
721
722
723
724
725
726
727

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """Resolve the tensor-parallel size for the draft model.

        An explicitly provided ``speculative_draft_tensor_parallel_size`` must
        be either 1 or the target model's tensor-parallel size, and is
        returned unchanged. When it is unset, the default is 1 for an
        ``mlp_speculator`` draft (with a warning if the target runs with
        tp > 1) and the target's tensor-parallel size for anything else.

        Returns:
            The draft model's tensor-parallel size.

        Raises:
            ValueError: if an explicit value is neither 1 nor the target's
                tensor-parallel size.
        """
        target_tp = target_parallel_config.tensor_parallel_size

        if speculative_draft_tensor_parallel_size is not None:
            # NOTE(review): an explicit value equal to target_tp (> 1) is
            # accepted even for mlp_speculator, which the warning below says
            # cannot run with tp>1 — confirm this is intended.
            if speculative_draft_tensor_parallel_size not in (1, target_tp):
                raise ValueError(
                    f"{speculative_draft_tensor_parallel_size=} cannot be "
                    f"other value than 1 or target model tensor_parallel_size"
                )
            return speculative_draft_tensor_parallel_size

        if draft_hf_config.model_type != "mlp_speculator":
            return target_tp

        if target_tp > 1:
            logger.warning(
                "%s cannot currently be run with tp>1; "
                "setting speculative_draft_tensor_parallel_size=1",
                draft_hf_config.model_type,
            )
        return 1

761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
    def update_arch_(self):
        """Recompute every architecture-derived field of the draft model config.

        The draft ``hf_config``'s architectures can be rewritten after
        construction (EagleConfig and ExtractHiddenStatesConfig do this).
        This refreshes, in order: ``hf_text_config``, ``model_arch_config``,
        and the cached registry lookup (``_model_info`` / ``_architecture``).
        """
        draft_cfg = self.draft_model_config

        # hf_text_config is refreshed first; the later lookups run after it.
        draft_cfg.hf_text_config = get_hf_text_config(draft_cfg.hf_config)
        draft_cfg.model_arch_config = draft_cfg.get_model_arch_config()

        resolved_info, resolved_arch = draft_cfg.registry.inspect_model_cls(
            draft_cfg.architectures,
            draft_cfg,
        )
        draft_cfg._model_info = resolved_info
        draft_cfg._architecture = resolved_arch

779
780
781
782
783
784
785
786
787
788
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Build the ParallelConfig used by the draft worker.

        The tensor-parallel size comes from
        ``speculative_draft_tensor_parallel_size``. Pipeline-parallel size,
        executor backend, max parallel loading workers, the custom all-reduce
        flag, the Ray nsight flag and the placement group are copied from the
        target config. Any other ParallelConfig field is left at
        ParallelConfig's own default.
        """
        inherited_fields = (
            "pipeline_parallel_size",
            "distributed_executor_backend",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "placement_group",
        )
        shared = {
            field_name: getattr(target_parallel_config, field_name)
            for field_name in inherited_fields
        }
        return ParallelConfig(
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            **shared,
        )

800
    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Cross-field validation that pydantic runs after model construction.

        Raises:
            ValueError: if the unsupported ``tensor_parallel_size`` key was
                given, or ``num_speculative_tokens`` is missing or not
                positive. Also raised when ``eagle3`` /
                ``extract_hidden_states`` / ``dflash`` is requested for a
                target whose model type matches none of the supported
                families (substring match). Finally raised, via
                ``verify_equal_vocab_size_if_draft_model``, when draft and
                target vocab sizes differ for ``draft_model``.
        """
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

        spec_tokens = self.num_speculative_tokens
        if spec_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )
        if spec_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({spec_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        # These methods need auxiliary hidden states from the target model,
        # which only some model families expose.
        needs_aux_hidden_states = self.method in (
            "eagle3",
            "extract_hidden_states",
            "dflash",
        )
        if needs_aux_hidden_states and self.target_model_config:
            aux_hidden_states_supported = [
                "llama",
                "qwen",
                "minicpm",
                "gpt_oss",
                "hunyuan_vl",
                "hunyuan_v1_dense",
                "afmoe",
                "nemotron_h",
                "deepseek_v2",
                "deepseek_v3",
                "kimi_k2",
                "kimi_k25",
                "minimax_m2",
                "gemma4",
            ]
            # Substring match: e.g. "qwen" also covers "qwen3_next".
            target_type = self.target_model_config.hf_text_config.model_type
            is_supported = any(
                family in target_type for family in aux_hidden_states_supported
            )
            if not is_supported:
                raise ValueError(
                    f"{self.method} is only supported for {aux_hidden_states_supported}"
                    f" models. Got {self.target_model_config.hf_text_config.model_type=}"
                )

        self.verify_equal_vocab_size_if_draft_model()
        return self

857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
    def verify_equal_vocab_size_if_draft_model(self):
        """Require matching vocab sizes when ``method == "draft_model"``.

        This is a no-op for any other method, or when either the target or the
        draft model config is missing.

        Raises:
            ValueError: if target and draft report different vocab sizes.
                Models with different tokenizers can cause out-of-bounds
                errors during speculative decoding.
        """
        if self.method != "draft_model":
            return
        target_cfg = self.target_model_config
        draft_cfg = self.draft_model_config
        if target_cfg is None or draft_cfg is None:
            return

        target_vocab_size = target_cfg.get_vocab_size()
        draft_vocab_size = draft_cfg.get_vocab_size()
        if target_vocab_size == draft_vocab_size:
            return

        raise ValueError(
            f"Target and draft model should have the same vocabulary size. "
            f"Target model vocab_size={target_vocab_size}. "
            f"Draft model vocab_size={draft_vocab_size}. "
            f"Using models with different tokenizers can cause out-of-bounds "
            f"errors during speculative decoding."
        )

874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
    @property
    def max_num_new_slots_for_drafting(self) -> int:
        """Upper bound on new per-request batch slots added while drafting.

        The result is the sum of two contributions:

        * parallel drafting: one slot per 'masked' token, i.e.
          ``num_speculative_tokens - 1``;
        * draft-model speculation: one extra slot, because the draft tokens
          are not sliced.

        Serial methods without a separate draft model need no extra slots.
        """
        masked_token_slots = (
            self.num_speculative_tokens - 1 if self.parallel_drafting else 0
        )
        draft_model_slot = 1 if self.uses_draft_model() else 0
        return masked_token_slots + draft_model_slot

890
    def use_eagle(self) -> bool:
        """Return True if ``self.method`` is eagle, eagle3, mtp or dflash."""
        eagle_family = ("eagle", "eagle3", "mtp", "dflash")
        return self.method in eagle_family

    def use_dflash(self) -> bool:
        """Return True if the speculative method is ``dflash``."""
        return "dflash" == self.method
895

896
897
898
    def uses_draft_model(self) -> bool:
        """Return True if the speculative method is ``draft_model``."""
        return "draft_model" == self.method

899
900
901
    def uses_extract_hidden_states(self) -> bool:
        """Return True if the speculative method is ``extract_hidden_states``."""
        return "extract_hidden_states" == self.method

902
903
904
    def use_ngram_gpu(self) -> bool:
        """Return True if the speculative method is ``ngram_gpu``."""
        return "ngram_gpu" == self.method

905
906
    def __repr__(self) -> str:
        method = self.method
907
908
909
910
911
        model = (
            None
            if method in ("ngram", "suffix", "extract_hidden_states")
            else self.draft_model_config.model
        )
912
913
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"