# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
from typing import TYPE_CHECKING, Any, Literal, get_args
6

7
from pydantic import Field, SkipValidation, model_validator
8
9
from typing_extensions import Self

10
from vllm.config import LoadConfig
11
from vllm.config.model import ModelConfig
12
13
14
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
15
from vllm.transformers_utils.config import get_hf_text_config
16
from vllm.utils.hashing import safe_hash
17
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
18
19
20
21
22
23
24
25

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    # At runtime `PretrainedConfig` is only referenced in annotations in this
    # module, so `Any` stands in for it and `transformers` is not imported.
    PretrainedConfig = Any

    # The quantization package is wrapped in vLLM's LazyLoader instead of
    # being imported directly; presumably the real import happens on first
    # attribute access (see vllm.utils.import_utils.LazyLoader).
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Model types of multi-token-prediction (MTP) draft heads. Used as a
# `method` value, every entry other than the generic "mtp" is a deprecated
# alias that `SpeculativeConfig.__post_init__` rewrites to "mtp".
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]
# EAGLE-style methods; nesting the Literal pulls in every MTP model type.
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
# Every value accepted by `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]


@config
class SpeculativeConfig:
    """Configuration for speculative decoding."""
61

62
    enforce_eager: bool | None = None
63
    """Override the default enforce_eager from model_config"""
64
    # General speculative decoding control
65
    num_speculative_tokens: int = Field(default=None, gt=0)
66
67
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
68
    model: str | None = None
69
70
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
71
    method: SpeculativeMethod | None = None
72
73
74
75
76
77
78
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
79
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
80
81
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
82
83
84
    tensor_parallel_size: int | None = None
    """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
    warn users when they mistakenly provide the wrong argument."""
85
86

    # Draft model configuration
87
    quantization: me_quant.QuantizationMethods | None = None
88
89
90
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
91
    max_model_len: int | None = Field(default=None, ge=1)
92
93
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
94
    revision: str | None = None
95
96
97
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
98
    code_revision: str | None = None
99
100
101
102
103
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
104
    disable_by_batch_size: int | None = Field(default=None, ge=2)
105
106
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
107
108
109
110
111
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""
112
113
114
115
116
    use_local_argmax_reduction: bool = False
    """Use vocab-parallel local argmax instead of all-gathering full logits
    for draft token generation. Reduces communication from O(vocab_size) to
    O(2 * tp_size) per token. Only applies to greedy draft selection in
    non-tree speculation."""
117
118

    # Ngram proposer configuration
119
    prompt_lookup_max: int | None = Field(default=None, ge=1)
120
121
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
122
    prompt_lookup_min: int | None = Field(default=None, ge=1)
123
124
125
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

126
    # Alternative drafting strategies
127
    speculative_token_tree: str | None = None
128
129
    """Specifies the tree structure for speculative token generation.
    """
130
131
132
133
134
135
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods."""

136
137
138
    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
139
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
140
141
142
143
144
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model initialized internal."""
145
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
146
147
    """The parallel configuration for the draft model initialized internal."""

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

169
170
171
172
    draft_load_config: LoadConfig | None = None
    """Load config for the draft model. If not specified, will use the load
    config from the target model."""

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = [
            # Eagle3 affects the computation graph because it returns
            # intermediate hidden states in addition to the final hidden state.
            self.method == "eagle3",
        ]
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Map a base model's HF config onto its speculative draft variant.

        Passed as `hf_overrides` when the draft `ModelConfig` is built. For
        each recognised model, `model_type` is switched to the matching
        `*_mtp` type, and `n_predict` is copied from `num_nextn_predict_layers`
        (`mtp_num_hidden_layers` for Qwen3.5; LongCat and Step3p5 default to 1
        when absent). `architectures` is then pointed at the MTP model class.
        Mistral Large 3 is redirected to its EAGLE architecture instead.
        Configs matching no rule are returned unchanged. Matching configs are
        mutated in place and also returned.
        """

        def first_architecture() -> str | None:
            # Re-read on every call: an earlier rule may already have replaced
            # `architectures`. Tolerate configs with no (or an empty) list.
            architectures = getattr(hf_config, "architectures", None)
            return architectures[0] if architectures else None

        initial_architecture = first_architecture()

        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Exact match: `in ("pangu_ultra_moe")` would be a substring test,
        # because the parentheses alone do not make a tuple.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if first_architecture() == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if first_architecture() == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if first_architecture() == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

        if first_architecture() == "GlmOcrForConditionalGeneration":
            hf_config.model_type = "glm_ocr_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["GlmOcrMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
            is_moe = hf_config.model_type == "qwen3_5_moe"
            hf_config.model_type = "qwen3_5_mtp"
            n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
                }
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

        # Keyed on the architecture *before* any rewrite above.
        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
        """Infer the speculative method and derive the draft-side configs.

        "method" is a newer parameter that helps extend the configuration of
        non-model-based proposers. The "model" parameter sets the draft model,
        eagle head, or additional weights when needed. If users do not specify
        "method", it is detected automatically where possible. When it cannot
        be detected, it is treated as "draft_model".

        Raises:
            ValueError: for inconsistent arguments (e.g. num_speculative_tokens
                without a model for a model-based method, bad ngram bounds,
                non-divisible num_speculative_tokens for MTP reuse).
            NotImplementedError: if a model is given but the method cannot be
                matched to a supported one.
        """
        # infer method from user args
        if self.method is None:
            if self.model in ("ngram", "[ngram]"):
                self.method = "ngram"
            else:
                self.method = "draft_model"

        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Fill in whichever lookup bound is missing. The first branch
            # covers "both missing", so each later branch has the other bound.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                draft_model_name = self.draft_model_config.model.lower()
                draft_model_type = self.draft_model_config.hf_config.model_type
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in draft_model_name:
                    self.method = "eagle"
                elif "eagle3" in draft_model_name:
                    self.method = "eagle3"
                elif draft_model_type == "medusa":
                    self.method = "medusa"
                elif draft_model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif draft_model_type in get_args(MTPModelTypes):
                    # Covers every MTP head, including "longcat_flash_mtp".
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer, "
                            "which may result in lower acceptance rate"
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config.hf_text_config = get_hf_text_config(
                            self.draft_model_config.hf_config
                        )
                        self.draft_model_config.model_arch_config = (
                            self.draft_model_config.get_model_arch_config()
                        )
                        model_info, arch = (
                            self.draft_model_config.registry.inspect_model_cls(
                                self.draft_model_config.architectures,
                                self.draft_model_config,
                            )
                        )
                        self.draft_model_config._model_info = model_info
                        self.draft_model_config._architecture = arch

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    if self.num_speculative_tokens is None:
                        raise ValueError(
                            "A speculative model was provided, but neither "
                            "`speculative_token_tree` nor `num_speculative_tokens` "
                            "was provided"
                        )

                    # Generate chain of tokens.
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self

    def _validate_suffix_decoding(self):
        """Validate the suffix-decoding settings and fill in defaults.

        If `num_speculative_tokens` is unset, it defaults to
        `suffix_decoding_max_tree_depth` and a warning is logged.

        Raises:
            ImportError: if Arctic Inference is not installed.
            ValueError: if a `suffix_decoding_*` option is out of range.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )

        # Validate values first, so an invalid tree depth is reported before
        # it could be copied into num_speculative_tokens below.
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model.

        Without an explicit `speculative_max_model_len`, this is the smaller of
        `draft_max_model_len` and `target_max_model_len`. This keeps sequences
        within the capacity of both the draft and the target model.

        `speculative_max_model_len`, when given, is returned as-is after
        checking it fits both models. It is mainly used for testing that
        sequences can skip speculation.

        Raises:
            ValueError: if `speculative_max_model_len` exceeds either the
                draft or the target max model length.
        """
        if speculative_max_model_len is None:
            return min(draft_max_model_len, target_max_model_len)

        if speculative_max_model_len > draft_max_model_len:
            raise ValueError(
                f"{speculative_max_model_len=} cannot be "
                f"larger than {draft_max_model_len=}"
            )
        if speculative_max_model_len > target_max_model_len:
            raise ValueError(
                f"{speculative_max_model_len=} cannot be "
                f"larger than {target_max_model_len=}"
            )
        return speculative_max_model_len

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.

        When unset, it defaults to 1 for `mlp_speculator` drafts (with a
        warning if the target uses tp>1). Otherwise it defaults to the
        target's tensor parallel size. An explicit value must be 1 or equal
        to the target's tensor parallel size, else ValueError is raised.
        """
        target_tp = target_parallel_config.tensor_parallel_size

        if speculative_draft_tensor_parallel_size is not None:
            if speculative_draft_tensor_parallel_size not in (1, target_tp):
                raise ValueError(
                    f"{speculative_draft_tensor_parallel_size=} cannot be "
                    f"other value than 1 or target model tensor_parallel_size"
                )
            return speculative_draft_tensor_parallel_size

        if draft_hf_config.model_type == "mlp_speculator":
            if target_tp > 1:
                logger.warning(
                    "%s cannot currently be run with tp>1; "
                    "setting speculative_draft_tensor_parallel_size=1",
                    draft_hf_config.model_type,
                )
            return 1

        return target_tp

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Build the ParallelConfig that the draft worker runs with.

        ``tensor_parallel_size`` is set to
        ``speculative_draft_tensor_parallel_size``. The other fields passed
        here are copied verbatim from ``target_parallel_config``. Fields not
        passed are left to ParallelConfig's own defaults and initialization.
        """
        src = target_parallel_config
        return ParallelConfig(
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            pipeline_parallel_size=src.pipeline_parallel_size,
            distributed_executor_backend=src.distributed_executor_backend,
            max_parallel_loading_workers=src.max_parallel_loading_workers,
            disable_custom_all_reduce=src.disable_custom_all_reduce,
            ray_workers_use_nsight=src.ray_workers_use_nsight,
            placement_group=src.placement_group,
        )

684
    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Cross-field validation, run by pydantic after the model is built.

        Checks run in this order, and the first failure raises:

        1. ``tensor_parallel_size`` must be unset. Callers should pass
           ``draft_tensor_parallel_size`` instead.
        2. ``num_speculative_tokens`` must be set and greater than zero.
        3. If a draft model config is present, it is checked against
           ``draft_parallel_config`` via ``verify_with_parallel_config``.
        4. ``disable_by_batch_size``, when set, must be at least 2.
        5. For ``eagle3``, the target model type must contain one of the
           supported family names.
        6. ``verify_equal_vocab_size_if_draft_model`` is run last.

        Returns:
            ``self``.

        Raises:
            ValueError: from checks 1, 2, 4 and 5, and from check 6 on a
                vocabulary size mismatch. Exceptions from check 3 come from
                ``verify_with_parallel_config``.
        """
        # The draft TP size has its own field; reject the generic name.
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

        n_spec = self.num_speculative_tokens
        if n_spec is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )
        if n_spec <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({n_spec})."
            )

        draft_config = self.draft_model_config
        if draft_config:
            draft_config.verify_with_parallel_config(self.draft_parallel_config)

        batch_threshold = self.disable_by_batch_size
        if batch_threshold is not None and batch_threshold < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        if self.method == "eagle3" and self.target_model_config:
            # Substring match: any model_type containing one of these names
            # (e.g. anything with "qwen" in it) is accepted.
            supported_targets = [
                "llama",
                "qwen",
                "minicpm",
                "gpt_oss",
                "hunyuan_vl",
                "hunyuan_v1_dense",
                "afmoe",
            ]
            target_type = self.target_model_config.hf_text_config.model_type
            if not any(family in target_type for family in supported_targets):
                raise ValueError(
                    f"Eagle3 is only supported for {supported_targets} models. "
                    f"Got {self.target_model_config.hf_text_config.model_type=}"
                )

        self.verify_equal_vocab_size_if_draft_model()
        return self

741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
    def verify_equal_vocab_size_if_draft_model(self):
        """Ensure a ``draft_model`` draft shares the target's vocabulary size.

        This is a no-op unless ``method == "draft_model"`` and both the target
        and draft model configs are set.

        Raises:
            ValueError: if the two ``get_vocab_size()`` results differ.
        """
        if self.method != "draft_model":
            return
        target_cfg = self.target_model_config
        draft_cfg = self.draft_model_config
        if target_cfg is None or draft_cfg is None:
            return

        target_size = target_cfg.get_vocab_size()
        draft_size = draft_cfg.get_vocab_size()
        if target_size != draft_size:
            raise ValueError(
                "Target and draft model should have the same vocabulary size. "
                f"Target model vocab_size={target_size}. "
                f"Draft model vocab_size={draft_size}. "
                "Using models with different tokenizers can cause out-of-bounds "
                "errors during speculative decoding."
            )

758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
    @property
    def max_num_new_slots_for_drafting(self) -> int:
        """
        Upper bound on the number of new slots drafting may add per request.

        The bound is the sum of two terms:
        - ``num_speculative_tokens - 1`` when ``parallel_drafting`` is on, and
          0 otherwise.
        - 1 when a separate draft model is used, and 0 otherwise.
        """
        # Parallel drafting needs one new slot per 'masked' token; serial
        # non-draft-model methods need none.
        masked_slots = self.num_speculative_tokens - 1 if self.parallel_drafting else 0
        # Draft-model speculation does not slice the draft tokens, so it needs
        # one extra slot per request.
        draft_model_slot = 1 if self.uses_draft_model() else 0
        return masked_slots + draft_model_slot

774
    def use_eagle(self) -> bool:
        """Return True when ``method`` is "eagle", "eagle3" or "mtp".

        NOTE(review): MTP variants such as "deepseek_mtp" are not in this
        tuple. Presumably they are normalized to "mtp" elsewhere; confirm.
        """
        eagle_family = ("eagle", "eagle3", "mtp")
        return self.method in eagle_family
776

777
778
779
    def uses_draft_model(self) -> bool:
        """Return True when speculation uses a separate draft model."""
        is_draft_model = self.method == "draft_model"
        return is_draft_model

780
781
    def __repr__(self) -> str:
        """Return a compact summary: method, draft model, and token count.

        ``model`` is reported as ``None`` for the ``ngram`` and ``suffix``
        methods. It is also ``None`` whenever ``draft_model_config`` is
        ``None``, so ``repr()`` never raises AttributeError on such a config.
        """
        method = self.method
        if method in ("ngram", "suffix"):
            model = None
        else:
            # Guard: dereferencing `.model` on a None config would make
            # repr() itself raise.
            draft_config = self.draft_model_config
            model = None if draft_config is None else draft_config.model
        num_spec_tokens = self.num_speculative_tokens
        # The `=` specifiers render the local names, so they must stay as-is.
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"