speculative.py 37.7 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
import copy
6
from typing import TYPE_CHECKING, Any, Literal, get_args
7

8
from pydantic import Field, SkipValidation, model_validator
9
10
from typing_extensions import Self

11
from vllm.config import LoadConfig
12
from vllm.config.kernel import MoEBackend
13
from vllm.config.model import ModelConfig
14
15
16
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
17
from vllm.transformers_utils.config import get_hf_text_config
18
from vllm.utils.hashing import safe_hash
19
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
20
21
22
23
24
25
26
27

# Heavy/optional modules are imported directly only for static type checking.
# At runtime `PretrainedConfig` is aliased to `Any`, and the quantization module
# is wrapped in `LazyLoader` (presumably so importing this config module does
# not eagerly import the model executor — TODO confirm).
if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    PretrainedConfig = Any

    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Model types of Multi-Token-Prediction (MTP) draft models. Besides the generic
# "mtp", these are the `model_type` values that `hf_config_override` assigns to
# rewritten target configs. As `method` values they are deprecated aliases that
# `SpeculativeConfig.__post_init__` collapses into "mtp".
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "nemotron_h_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]

# EAGLE-family methods. Nested `Literal`s are flattened, so every MTP model type
# is also a member of this alias.
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]

# The `ngram_gpu` variant of the n-gram proposer.
NgramGPUTypes = Literal["ngram_gpu"]

# Every accepted value of `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
    NgramGPUTypes,
]

# Accepted values of `SpeculativeConfig.rejection_sample_method`.
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]

@config
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config."""

    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)  # type: ignore[assignment]
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config (`n_predict`) if present, otherwise, it
    is required. For suffix decoding it defaults to
    `suffix_decoding_max_tree_depth`."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If `model` is provided, the
    method is detected automatically when possible (falling back to
    `draft_model`); if `model` is not provided, the method name must be
    provided.

    When using the `ngram` or `ngram_gpu` method, the related settings
    `prompt_lookup_max` and `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    tensor_parallel_size: int | None = None
    """Not a real setting: users should pass `draft_tensor_parallel_size`
    instead. This field only exists so that users who mistakenly provide
    `tensor_parallel_size` can be warned."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | str | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    moe_backend: MoEBackend | None = None
    """MoE backend to use for the draft model. When `None`, the draft model
    inherits the target model's `--moe-backend` setting. Useful when the
    drafter and generator require different MoE kernels (e.g. quantized
    generator with unquantized drafter)."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""
    use_local_argmax_reduction: bool = False
    """Use vocab-parallel local argmax instead of all-gathering full logits
    for draft token generation. Reduces communication from O(vocab_size) to
    O(2 * tp_size) per token. Only applies to greedy draft selection in
    non-tree speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of the ngram token window when using the ngram proposer.
    If unset, defaults to `prompt_lookup_min`, or to 5 when neither bound is
    set."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of the ngram token window when using the ngram proposer.
    If unset, defaults to `prompt_lookup_max`, or to 5 when neither bound is
    set. Must not exceed `prompt_lookup_max`."""

    # Alternative drafting strategies
    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation, as the
    string form of a Python list of tuples (e.g. `"[(0,), (0, 0)]"`). It is
    parsed with `ast.literal_eval` and sorted breadth-first. If unset when a
    draft model is used, a linear chain of `num_speculative_tokens` nodes is
    generated."""
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods."""

    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized
    internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    draft_load_config: LoadConfig | None = None
    """Load config for the draft model. If not specified, will use the load
    config from the target model."""

    rejection_sample_method: RejectionSampleMethod = "strict"
    """The rejection sampling strategy. `strict` accepts a draft token only if
    it exactly matches the token sampled by the target model; `probabilistic`
    uses probabilistic rejection sampling. Both respect the target model
    distribution, but `probabilistic` yields a higher acceptance rate at the
    cost of more memory to cache draft logits. `synthetic` accepts draft
    tokens according to `synthetic_acceptance_rate`."""

    synthetic_acceptance_rate: float | None = None
    """Average acceptance rate for synthetic rejection sampling. Draft
    tokens are accepted with a position-dependent probability that decays
    geometrically, calibrated so that the mean rate across all speculative
    positions equals this value. Only used when rejection_sample_method
    is 'synthetic'. Must be in [0, 1]."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            Hex digest of the string form of the collected factors.
        """
        factors: list[Any] = []

        # Eagle3 and extract_hidden_states affect the computation graph because
        # they return intermediate hidden states in addition to the final hidden state.
        uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
        factors.append(uses_aux_hidden_states)

        # The specific layers used also affect the computation graph
        if uses_aux_hidden_states and self.draft_model_config is not None:
            layer_ids = getattr(
                self.draft_model_config.hf_config,
                "eagle_aux_hidden_state_layer_ids",
                None,
            )
            if layer_ids is not None:
                # Convert to tuple so the factor has a stable, list-free repr.
                factors.append(tuple(layer_ids))

        # Not a security boundary: the hash is only used as a cache identity.
        return safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a target model's HF config so it loads as its MTP draft model.

        Used as ``hf_overrides`` when the draft ``ModelConfig`` is built. For
        each supported family the config's ``model_type`` is switched to the
        matching ``*_mtp`` type, ``n_predict`` is copied from
        ``num_nextn_predict_layers`` (``mtp_num_hidden_layers`` for Qwen3.5),
        and ``architectures`` is pointed at the MTP model class. For
        ``MistralLarge3ForCausalLM`` only the architecture is swapped to its
        EAGLE variant. Configs that match no rule are returned unchanged.

        The rules run in order and later rules see earlier rewrites (e.g.
        ``deepseek_v3`` is first renamed to ``deepseek_mtp`` and then handled
        by the ``deepseek_mtp`` rule), so their order matters.

        Args:
            hf_config: The HF config to rewrite; it is modified in place.

        Returns:
            The same ``hf_config`` object.
        """
        # Captured up front because several rules below replace `architectures`.
        initial_architecture = hf_config.architectures[0]

        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Exact match on purpose: `x in ("pangu_ultra_moe")` is a substring test,
        # because a parenthesized single string is not a tuple.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

        if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
            hf_config.model_type = "glm_ocr_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["GlmOcrMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        # Nemotron-H is only treated as an MTP variant when it declares MTP
        # layers; a missing or `None` layer count means "no MTP layers" instead
        # of raising a TypeError on `None > 0`.
        if (
            hf_config.model_type in {"nemotron_h", "nemotron_h_puzzle"}
            and (getattr(hf_config, "num_nextn_predict_layers", None) or 0) > 0
        ):
            hf_config.model_type = "nemotron_h_mtp"
        if hf_config.model_type == "nemotron_h_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
            is_moe = hf_config.model_type == "qwen3_5_moe"
            hf_config.model_type = "qwen3_5_mtp"
            n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
                }
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the draft configuration.

        Note: "method" is a new parameter that helps to extend the
        configuration of non-model-based proposers, and the "model" parameter
        is used to set the draft model, eagle head, or additional weights when
        needed. If users do not specify "method", the speculative method is
        detected automatically if possible; if it cannot be detected, it is
        treated as "draft_model".

        Steps, in order:

        1. Infer ``method`` when unset and collapse the deprecated
           per-family MTP method names into ``"mtp"``.
        2. If ``num_speculative_tokens`` is set but ``model`` is not, derive
           ``model`` from ``method`` (for MTP the target model is reused).
        3. Configure the selected proposer: n-gram window bounds, suffix
           decoding validation, hidden-state extraction, or a full draft
           ``ModelConfig`` with method auto-detection, token tree, draft TP
           size, max model length and draft parallel config.

        Returns:
            ``self``.

        Raises:
            ValueError: On inconsistent arguments, e.g. ``num_speculative_tokens``
                without a usable model, ``prompt_lookup_min`` greater than
                ``prompt_lookup_max``, or a token count that is larger than and
                not divisible by the draft model's ``n_predict``.
            NotImplementedError: If the draft model's method is unsupported.
        """
        # infer method from user args
        if self.method is None:
            if self.model in ("ngram", "[ngram]"):
                self.method = "ngram"
            else:
                self.method = "draft_model"

        # Per-family MTP method names are deprecated aliases of "mtp".
        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "ngram_gpu":
                self.model = "ngram_gpu"
            elif self.method == "suffix":
                self.model = "suffix"
            elif self.method == "extract_hidden_states":
                self.model = "extract_hidden_states"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        if self.method in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "ngram_gpu"):
            # Fill in a missing window bound from the other one; when both are
            # missing, use the defaults below.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        elif self.method == "extract_hidden_states":
            from vllm.transformers_utils.configs.extract_hidden_states import (
                ExtractHiddenStatesConfig,
            )

            # ExtractHiddenStatesModel is instantiated manually in load_model()
            # We just need to store the target model config for KV cache shape info
            self.model = "extract_hidden_states"
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            # A user-supplied draft config may be a ModelConfig-like object or a
            # plain dict; either way only its `hf_config` overrides are kept.
            if hasattr(self.draft_model_config, "hf_config"):
                hf_config = self.draft_model_config.hf_config.to_dict()
            elif (
                isinstance(self.draft_model_config, dict)
                and "hf_config" in self.draft_model_config
            ):
                hf_config = self.draft_model_config["hf_config"]
            else:
                hf_config = {}

            self.draft_model_config = copy.copy(self.target_model_config)
            self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
                self.draft_model_config.hf_config, **hf_config
            )
            self.update_arch_()
            self.draft_parallel_config = self.target_parallel_config
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from `n_predict` further below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer"
                            ",which may result in lower acceptance rate"
                        )
                # NOTE(review): "longcat_flash_mtp" is also listed in
                # MTPModelTypes, so the branch above currently matches it first
                # and this branch is unreachable. Exact match avoids the
                # substring test that `in ("longcat_flash_mtp")` performed.
                elif (
                    self.draft_model_config.hf_config.model_type == "longcat_flash_mtp"
                ):
                    self.method = "longcat_flash_mtp"
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "LongCat MTP models only have "
                            "one layer. Might need some code changes "
                            "to support multiple layers."
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig
                    from vllm.transformers_utils.configs.speculators import (
                        SpeculatorsConfig,
                    )

                    # Configs that are already EAGLE/speculators configs are
                    # kept as-is; anything else is wrapped in an EAGLEConfig.
                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        self.draft_model_config.hf_config = eagle_config
                        self.update_arch_()

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    if self.num_speculative_tokens is None:
                        raise ValueError(
                            "A speculative model was provided, but neither "
                            "`speculative_token_tree` nor `num_speculative_tokens` "
                            "was provided"
                        )

                    # Generate chain of tokens: [(0,), (0, 0), (0, 0, 0), ...].
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self

    def _validate_suffix_decoding(self):
        """Check suffix-decoding prerequisites and settings.

        Raises:
            ImportError: If Arctic Inference is not installed.
            ValueError: If any ``suffix_decoding_*`` setting is out of range.

        When ``num_speculative_tokens`` is unset it defaults to
        ``suffix_decoding_max_tree_depth``. The settings are validated first,
        so that default is never taken from an invalid tree depth.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )

        # Validate values
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
659
        speculative_max_model_len: int | None,
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
677
678
679
680
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
681
682

            if speculative_max_model_len > target_max_model_len:
683
684
685
686
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
687
688
689
690
691
692
693
694
695
696

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """Resolve the tensor-parallel size to use for the draft model.

        An explicit ``speculative_draft_tensor_parallel_size`` must be either 1
        or the target's ``tensor_parallel_size`` and is returned unchanged.
        When unset, it defaults to 1 for ``mlp_speculator`` drafts (with a
        warning when the target runs with tp > 1) and to the target's
        ``tensor_parallel_size`` for every other draft model type.

        Raises:
            ValueError: if an explicit size is neither 1 nor the target's tp.
        """
        target_tp = target_parallel_config.tensor_parallel_size

        # Explicitly requested size: verify only.
        if speculative_draft_tensor_parallel_size is not None:
            if speculative_draft_tensor_parallel_size not in (1, target_tp):
                raise ValueError(
                    f"{speculative_draft_tensor_parallel_size=} cannot be "
                    "other value than 1 or target model tensor_parallel_size"
                )
            return speculative_draft_tensor_parallel_size

        # Unset: follow the target, except for mlp_speculator.
        if draft_hf_config.model_type != "mlp_speculator":
            return target_tp

        if target_tp > 1:
            logger.warning(
                "%s cannot currently be run with tp>1; "
                "setting speculative_draft_tensor_parallel_size=1",
                draft_hf_config.model_type,
            )
        return 1

730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
    def update_arch_(self):
        """Recompute the architecture-derived fields of ``draft_model_config``.

        EagleConfig and ExtractHiddenStatesConfig update the draft model's
        architectures, so the text config, the model-arch config, and the
        model info / architecture returned by ``registry.inspect_model_cls``
        are all refreshed here in place.
        """
        draft_cfg = self.draft_model_config
        draft_cfg.hf_text_config = get_hf_text_config(draft_cfg.hf_config)
        draft_cfg.model_arch_config = draft_cfg.get_model_arch_config()
        info, resolved_arch = draft_cfg.registry.inspect_model_cls(
            draft_cfg.architectures,
            draft_cfg,
        )
        draft_cfg._model_info = info
        draft_cfg._architecture = resolved_arch

748
749
750
751
752
753
754
755
756
757
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Build the ParallelConfig used by the draft worker.

        Pipeline-parallel size, executor backend, loading-worker limit, the
        custom all-reduce flag, the Ray nsight flag and the placement group
        are copied from ``target_parallel_config``; only the tensor-parallel
        size is replaced by ``speculative_draft_tensor_parallel_size``.
        """
        src = target_parallel_config
        return ParallelConfig(
            pipeline_parallel_size=src.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=src.distributed_executor_backend,
            max_parallel_loading_workers=src.max_parallel_loading_workers,
            disable_custom_all_reduce=src.disable_custom_all_reduce,
            ray_workers_use_nsight=src.ray_workers_use_nsight,
            placement_group=src.placement_group,
        )

769
    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Cross-field validation run by pydantic after the fields are set.

        Checks, in order:
          1. ``tensor_parallel_size`` was not passed (the draft TP option is
             ``draft_tensor_parallel_size``);
          2. ``num_speculative_tokens`` is set and strictly positive;
          3. the draft model config, if present, agrees with the draft
             parallel config;
          4. for ``eagle3`` / ``extract_hidden_states``, the target model
             type is one that supports auxiliary hidden states;
          5. draft and target vocab sizes match for ``draft_model``.

        Raises:
            ValueError: on the first failed check.
        """
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

        n_spec_tokens = self.num_speculative_tokens
        if n_spec_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )
        if n_spec_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({n_spec_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        aux_hidden_state_methods = ("eagle3", "extract_hidden_states")
        if self.method in aux_hidden_state_methods and self.target_model_config:
            # Entries are matched as substrings of the target's model_type,
            # so an entry like "qwen" accepts any model_type containing it.
            supported_families = [
                "llama",
                "qwen",
                "minicpm",
                "gpt_oss",
                "hunyuan_vl",
                "hunyuan_v1_dense",
                "afmoe",
                "nemotron_h",
                "deepseek_v2",
                "deepseek_v3",
                "kimi_k2",
                "kimi_k25",
                "minimax_m2",
                "gemma4",
            ]
            target_type = self.target_model_config.hf_text_config.model_type
            if not any(family in target_type for family in supported_families):
                raise ValueError(
                    f"{self.method} is only supported for {supported_families}"
                    f" models. Got "
                    f"{self.target_model_config.hf_text_config.model_type=}"
                )

        self.verify_equal_vocab_size_if_draft_model()
        return self

826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
    def verify_equal_vocab_size_if_draft_model(self):
        """Require equal vocab sizes when speculating with a separate draft model.

        Only applies to ``method == "draft_model"`` with both the target and
        draft model configs present; otherwise this is a no-op.

        Raises:
            ValueError: if the target and draft vocab sizes differ.
        """
        if self.method != "draft_model":
            return
        if self.target_model_config is None or self.draft_model_config is None:
            return

        target_vocab = self.target_model_config.get_vocab_size()
        draft_vocab = self.draft_model_config.get_vocab_size()
        if target_vocab == draft_vocab:
            return

        raise ValueError(
            "Target and draft model should have the same vocabulary size. "
            f"Target model vocab_size={target_vocab}. "
            f"Draft model vocab_size={draft_vocab}. "
            "Using models with different tokenizers can cause out-of-bounds "
            "errors during speculative decoding."
        )

843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
    @property
    def max_num_new_slots_for_drafting(self) -> int:
        """Upper bound on extra per-request batch slots added while drafting.

        Serial methods without a draft model add none. Parallel drafting adds
        one slot per 'masked' token (``num_speculative_tokens - 1``), and
        draft-model speculation adds one more because the draft tokens are
        not sliced.
        """
        extra_slots = self.num_speculative_tokens - 1 if self.parallel_drafting else 0
        if self.uses_draft_model():
            extra_slots += 1
        return extra_slots

859
    def use_eagle(self) -> bool:
        """Whether ``method`` is one of ``eagle``, ``eagle3`` or ``mtp``."""
        eagle_family = {"eagle", "eagle3", "mtp"}
        return self.method in eagle_family
861

862
863
864
    def uses_draft_model(self) -> bool:
        """Whether ``method`` is ``draft_model`` (a separate draft LM)."""
        method = self.method
        return method == "draft_model"

865
866
867
    def uses_extract_hidden_states(self) -> bool:
        """Whether ``method`` is ``extract_hidden_states``."""
        method = self.method
        return method == "extract_hidden_states"

868
869
870
    def use_ngram_gpu(self) -> bool:
        """Whether ``method`` is ``ngram_gpu``."""
        method = self.method
        return method == "ngram_gpu"

871
872
    def __repr__(self) -> str:
        method = self.method
873
874
875
876
877
        model = (
            None
            if method in ("ngram", "suffix", "extract_hidden_states")
            else self.draft_model_config.model
        )
878
879
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"