# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
from typing import TYPE_CHECKING, Any, Literal, get_args
6

7
from pydantic import Field, SkipValidation, model_validator
8
9
from typing_extensions import Self

10
from vllm.config import LoadConfig
11
from vllm.config.model import ModelConfig
12
13
14
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
15
from vllm.transformers_utils.config import get_hf_text_config
16
from vllm.utils.hashing import safe_hash
17
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
18
19
20
21
22
23
24
25

# `transformers` and the quantization package are only imported for type
# checkers; at runtime `PretrainedConfig` is an `Any` alias (it is used only
# in annotations in this module) and `me_quant` goes through `LazyLoader`
# (presumably deferring the real import until first attribute access).
if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    PretrainedConfig = Any

    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Model types whose built-in multi-token-prediction (MTP) layers can serve as
# the draft. Any of these other than "mtp" passed as `method` is treated as a
# deprecated alias and rewritten to "mtp" by `SpeculativeConfig.__post_init__`.
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "nemotron_h_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]
# Methods that draft with an EAGLE-style head; nested Literals are flattened,
# so this also contains every MTP model type.
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
# Every value accepted by `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]


@config
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config. Forced to True when
    MTP drafting reuses a DeepSeek-V3.2 (`deepseek_v32`) target model."""
    # General speculative decoding control
    num_speculative_tokens: int | None = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible. If `method` is left unset, it defaults to `ngram` when `model`
    is "ngram" and to `draft_model` otherwise.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    tensor_parallel_size: int | None = None
    """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
    warn users when they mistakenly provide the wrong argument."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""
    use_local_argmax_reduction: bool = False
    """Use vocab-parallel local argmax instead of all-gathering full logits
    for draft token generation. Reduces communication from O(vocab_size) to
    O(2 * tp_size) per token. Only applies to greedy draft selection in
    non-tree speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    # Alternative drafting strategies
    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation, as a
    Python-literal string of token-path tuples, e.g. "[(0,), (0, 0)]". It is
    parsed with `ast.literal_eval` and sorted breadth-first. If unset for a
    model-based method, a single chain of `num_speculative_tokens` tokens is
    generated."""
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods."""

    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    draft_load_config: LoadConfig | None = None
    """Load config for the draft model. If not specified, will use the load
    config from the target model."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            Hex digest of the stringified list of graph-affecting factors.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a model's HF config so it loads as its MTP draft module.

        Used as the ``hf_overrides`` hook when the draft ``ModelConfig`` is
        built in ``__post_init__``. For each supported model family the
        ``model_type`` is switched to the matching ``*_mtp`` type, and
        ``n_predict`` (read from ``num_nextn_predict_layers``, or from
        ``mtp_num_hidden_layers`` for Qwen3.5) and ``architectures`` are
        updated via ``hf_config.update``.

        The rules run in sequence against the *current* ``model_type`` /
        ``architectures``, so a rewrite made by an earlier rule is visible to
        later ones. Only the final Mistral Large 3 rule looks at the
        architecture the config had on entry.

        Args:
            hf_config: The HF config to rewrite; it is mutated in place.

        Returns:
            The same ``hf_config`` object.
        """
        initial_architecture = hf_config.architectures[0]
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Exact match on purpose: `x in ("pangu_ultra_moe")` (no trailing comma)
        # is a substring test on a plain string and would also match e.g. ""
        # or "moe", rewriting unrelated draft configs.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

        if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
            hf_config.model_type = "glm_ocr_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["GlmOcrMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if (
            hf_config.model_type == "nemotron_h"
            and hasattr(hf_config, "num_nextn_predict_layers")
            and hf_config.num_nextn_predict_layers > 0
        ):
            # Only Nemotron-H checkpoints that actually ship MTP layers are
            # treated as the MTP variant.
            hf_config.model_type = "nemotron_h_mtp"
        if hf_config.model_type == "nemotron_h_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
            is_moe = hf_config.model_type == "qwen3_5_moe"
            hf_config.model_type = "qwen3_5_mtp"
            # Qwen3.5 stores its MTP depth under a different key.
            n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
                }
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the dependent settings.

        In order, this:

        1. infers ``method`` when unset and folds deprecated per-model MTP
           method names into ``"mtp"``;
        2. fills in ``model`` when only ``num_speculative_tokens`` was given
           (MTP reuses the target model; ngram/suffix use their method name);
        3. for ``ngram``: defaults and validates the prompt-lookup window and
           reuses the target model/parallel configs;
        4. for ``suffix``: validates the suffix-decoding options;
        5. otherwise, when a draft ``model`` is known: builds
           ``draft_model_config``, refines ``method`` from the draft model's
           name/config, wraps plain EAGLE heads in ``EAGLEConfig`` and derives
           ``num_speculative_tokens``, ``speculative_token_tree``, the draft TP
           size, the draft max model length and ``draft_parallel_config``.

        Raises:
            ValueError: For missing, inconsistent or out-of-range options.
            ImportError: For ``suffix`` without Arctic Inference installed.
            NotImplementedError: If ``method`` does not fit the draft model.
        """
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        # infer method from user args
        if self.method is None:
            if self.model in ("ngram", "[ngram]"):
                self.method = "ngram"
            else:
                self.method = "draft_model"

        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided. When only one bound is given,
            # the other mirrors it (the both-None case is handled first).
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    # This also covers "longcat_flash_mtp", so no separate
                    # LongCat branch is needed.
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer"
                            ", which may result in lower acceptance rate"
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config.hf_text_config = get_hf_text_config(
                            self.draft_model_config.hf_config
                        )
                        self.draft_model_config.model_arch_config = (
                            self.draft_model_config.get_model_arch_config()
                        )
                        model_info, arch = (
                            self.draft_model_config.registry.inspect_model_cls(
                                self.draft_model_config.architectures,
                                self.draft_model_config,
                            )
                        )
                        self.draft_model_config._model_info = model_info
                        self.draft_model_config._architecture = arch

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    if self.num_speculative_tokens is None:
                        raise ValueError(
                            "A speculative model was provided, but neither "
                            "`speculative_token_tree` nor `num_speculative_tokens` "
                            "was provided"
                        )

                    # Generate chain of tokens: [(0,), (0, 0), (0, 0, 0), ...]
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first (by depth, then path).
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        # Dataclass machinery ignores __post_init__'s return value; kept as-is
        # for compatibility with any direct callers.
        return self

    def _validate_suffix_decoding(self):
        """Check suffix-decoding prerequisites and option ranges.

        If ``num_speculative_tokens`` is unset, it defaults to
        ``suffix_decoding_max_tree_depth`` and a warning is logged.

        Raises:
            ImportError: If Arctic Inference is not installed.
            ValueError: If any ``suffix_decoding_*`` option is out of range.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )
        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )
        # Validate values
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: If speculative_max_model_len exceeds either the draft or
                the target max model length.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.

        When unset, the draft TP size defaults to 1 for ``mlp_speculator``
        drafts (with a warning if the target uses TP > 1) and to the target
        model's TP size otherwise. When set, it must be 1 or equal to the
        target model's TP size.

        Raises:
            ValueError: If an explicit size is neither 1 nor the target TP size.
        """
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
                        draft_hf_config.model_type,
                    )
            else:
                speculative_draft_tensor_parallel_size = (
                    target_parallel_config.tensor_parallel_size
                )
        elif speculative_draft_tensor_parallel_size not in (
            1,
            target_parallel_config.tensor_parallel_size,
        ):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1 or target model tensor_parallel_size"
            )
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Build the ``ParallelConfig`` the draft worker runs with.

        The forwarded fields are copied from ``target_parallel_config``, except
        ``tensor_parallel_size``, which is set to
        ``speculative_draft_tensor_parallel_size``. Fields not listed below are
        not forwarded from the target config.
        """
        src = target_parallel_config
        return ParallelConfig(
            pipeline_parallel_size=src.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=src.distributed_executor_backend,
            max_parallel_loading_workers=src.max_parallel_loading_workers,
            disable_custom_all_reduce=src.disable_custom_all_reduce,
            ray_workers_use_nsight=src.ray_workers_use_nsight,
            placement_group=src.placement_group,
        )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate cross-field invariants once all fields are populated.

        The checks run in this order, and the first failure raises:

        1. ``tensor_parallel_size`` must not be set; callers pass
           ``draft_tensor_parallel_size`` instead.
        2. ``num_speculative_tokens`` must be provided and be greater than 0.
        3. If a draft model config is present, it must pass
           ``verify_with_parallel_config`` against ``draft_parallel_config``.
        4. ``eagle3`` requires a target whose text-config ``model_type``
           contains one of the supported architecture names (substring match).
        5. ``verify_equal_vocab_size_if_draft_model`` must pass.

        Returns:
            ``self`` (a pydantic ``after`` validator returns the instance).

        Raises:
            ValueError: If one of the checks above fails.
        """
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

        spec_tokens = self.num_speculative_tokens
        if spec_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )
        if spec_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({spec_tokens})."
            )

        draft_config = self.draft_model_config
        if draft_config:
            draft_config.verify_with_parallel_config(self.draft_parallel_config)

        target_config = self.target_model_config
        if self.method == "eagle3" and target_config:
            # Substring match: e.g. "qwen" also accepts "qwen3_moe".
            eagle3_archs = [
                "llama",
                "qwen",
                "minicpm",
                "gpt_oss",
                "hunyuan_vl",
                "hunyuan_v1_dense",
                "afmoe",
                "nemotron_h",
            ]
            target_type = target_config.hf_text_config.model_type
            if not any(arch in target_type for arch in eagle3_archs):
                # The `=` specifier prints the expression text, so it is kept
                # verbatim to preserve the error message.
                raise ValueError(
                    f"Eagle3 is only supported for {eagle3_archs} models. "
                    f"Got {self.target_model_config.hf_text_config.model_type=}"
                )

        self.verify_equal_vocab_size_if_draft_model()
        return self

    def verify_equal_vocab_size_if_draft_model(self):
        """Ensure a ``draft_model`` drafter shares the target's vocab size.

        This applies only when ``method == "draft_model"`` and both the target
        and draft model configs are set; otherwise it does nothing.

        Raises:
            ValueError: If the two models report different vocab sizes.
        """
        if self.method != "draft_model":
            return
        target_cfg = self.target_model_config
        if target_cfg is None:
            return
        draft_cfg = self.draft_model_config
        if draft_cfg is None:
            return

        target_vocab_size = target_cfg.get_vocab_size()
        draft_vocab_size = draft_cfg.get_vocab_size()
        if target_vocab_size == draft_vocab_size:
            return

        raise ValueError(
            "Target and draft model should have the same vocabulary size. "
            f"Target model vocab_size={target_vocab_size}. "
            f"Draft model vocab_size={draft_vocab_size}. "
            "Using models with different tokenizers can cause out-of-bounds "
            "errors during speculative decoding."
        )

    @property
    def max_num_new_slots_for_drafting(self) -> int:
        """Upper bound on the new batch slots each request may add while drafting.

        * Parallel drafting needs one slot per masked draft position, i.e.
          ``num_speculative_tokens - 1``.
        * Draft-model speculation needs one extra slot, because its draft
          tokens are not sliced.
        * Serial drafting with the other methods needs none.
        """
        masked_slots = (
            self.num_speculative_tokens - 1 if self.parallel_drafting else 0
        )
        draft_model_slot = 1 if self.uses_draft_model() else 0
        return masked_slots + draft_model_slot

    def use_eagle(self) -> bool:
        """Return True when the configured method is one of the EAGLE-style ones.

        The accepted method names are ``"eagle"``, ``"eagle3"`` and ``"mtp"``.
        """
        # NOTE(review): model-specific MTP names such as "deepseek_mtp" are not
        # matched here; presumably they are normalised to "mtp" elsewhere.
        # Confirm before relying on this for those names.
        eagle_style_methods = ("eagle", "eagle3", "mtp")
        return self.method in eagle_style_methods

    def uses_draft_model(self) -> bool:
        """Return True iff the speculation method is ``"draft_model"``."""
        is_draft_model = self.method == "draft_model"
        return is_draft_model

    def __repr__(self) -> str:
        """Return a short summary: method, draft model name and spec-token count.

        ``model`` is ``None`` for the model-free methods (``ngram``, ``suffix``)
        and also when ``draft_model_config`` is ``None``. ``repr`` therefore
        never raises on a config without a draft model attached.
        """
        method = self.method
        if method in ("ngram", "suffix"):
            model = None
        else:
            # draft_model_config can be None (other code in this class checks
            # for that); repr() is used in logs/errors and must not raise.
            draft_config = self.draft_model_config
            model = None if draft_config is None else draft_config.model
        num_spec_tokens = self.num_speculative_tokens
        # The `=` specifier prints the local names, so they are kept as-is.
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"