speculative.py 32.2 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
from typing import TYPE_CHECKING, Any, Literal, get_args
6

7
from pydantic import Field, SkipValidation, model_validator
8
9
from typing_extensions import Self

10
from vllm.config.model import ModelConfig
11
12
13
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
14
from vllm.transformers_utils.config import get_hf_text_config
15
from vllm.utils.hashing import safe_hash
16
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
17
18
19
20
21
22
23
24

# Static type checkers see the real types; at runtime `PretrainedConfig`
# falls back to `Any` and the quantization module is wrapped in `LazyLoader`
# (presumably deferring the import until it is first used).
if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    PretrainedConfig = Any

    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

# Model types of multi-token-prediction (MTP) draft modules. Every entry other
# than the generic "mtp" is also accepted as a legacy `method` value and is
# normalized to "mtp" in `SpeculativeConfig.__post_init__`.
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]

# A Literal nested inside a Literal is flattened, so both aliases below also
# contain every MTP model type.
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]

logger = init_logger(__name__)
54
55
56
57
58


@config
class SpeculativeConfig:
    """Configuration for speculative decoding."""
59

60
    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config."""
    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present (or, for suffix decoding, to
    `suffix_decoding_max_tree_depth`); otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible. If the `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    tensor_parallel_size: int | None = None
    """Users should pass "draft_tensor_parallel_size". This parameter only
    exists so that mistakenly passing it is rejected with a clear error."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    # Alternative drafting strategies
    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation, as the
    string form of a list of tuples (parsed with `ast.literal_eval`). When a
    draft model is used and this is unset, a single chain of
    `num_speculative_tokens` nodes is generated."""
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods."""

    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
    def compute_hash(self) -> str:
        """Return a hash of the settings that change the compiled model graph.

        WARNING: whenever a new field is added to this config, add it to the
        factors below if it affects the computation graph.

        The hash identifies every setting that alters the structure of the
        computation graph from input ids/embeddings to the final hidden
        states; anything before the input ids/embeddings or after the final
        hidden states is deliberately left out.
        """
        # EAGLE-3 changes the graph: it returns intermediate hidden states in
        # addition to the final hidden state.
        factors: list[Any] = [self.method == "eagle3"]
        digest = safe_hash(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
183
        initial_architecture = hf_config.architectures[0]
Jee Jee Li's avatar
Jee Jee Li committed
184
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
185
186
187
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
188
189
190
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )
191
192
193
194
195
196
197
        if hf_config.model_type in ("pangu_ultra_moe"):
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )
198
199
200
201

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
202
203
204
205
206
207
208
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )
209
210
211
212

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
213
214
215
216
217
218
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )
219

220
221
222
223
224
225
226
227
228
229
230
        if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

231
232
233
234
235
236
237
238
239
240
241
        if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
            hf_config.model_type = "glm_ocr_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["GlmOcrMTPModel"],
                }
            )

242
243
244
245
        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
246
247
248
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )
249
250
251
252
253

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
254
255
256
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )
Kyungmin Lee's avatar
Kyungmin Lee committed
257
258
259
260
261
262
263
264
265

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

XuruiYang's avatar
XuruiYang committed
266
267
268
        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
269
270
271
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )
272

csy0225's avatar
csy0225 committed
273
274
275
276
277
        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

278
279
280
        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

281
282
283
284
285
286
287
288
289
290
291
        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the draft-side configs.

        Normalizes legacy MTP method names to "mtp", infers `model`/`method`
        from each other where possible, and fills ngram defaults. For
        model-based methods it builds `draft_model_config` and
        `draft_parallel_config` and defaults `num_speculative_tokens` from the
        draft model's `n_predict`. It also normalizes `speculative_token_tree`.

        Raises:
            ValueError: on inconsistent arguments (see the individual checks).
            NotImplementedError: if a draft model is given, the method cannot
                be detected from it, and `method` is not "draft_model".
        """
        # Note: "method" is a newer parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # selects the draft model, eagle head, or additional weights when
        # needed. If users do not specify "method", it is detected
        # automatically from the draft model where possible. If it cannot be
        # detected, "method" must be set explicitly (e.g. to "draft_model"),
        # otherwise NotImplementedError is raised below.

        # Model-specific MTP method names are deprecated aliases of "mtp".
        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method". (`None in (...)` is False, so no None check.)
        if self.method is None and self.model in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided; a single provided bound is
            # mirrored onto the missing one.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: currently we still need to extract vocab_size from the
            # target model config; in the future we may refactor this out and
            # set the draft-related configs to None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    # Also covers "longcat_flash_mtp", which is an MTP type.
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer, "
                            "which may result in lower acceptance rate"
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config.hf_text_config = get_hf_text_config(
                            self.draft_model_config.hf_config
                        )
                        self.draft_model_config.model_arch_config = (
                            self.draft_model_config.get_model_arch_config()
                        )
                        model_info, arch = (
                            self.draft_model_config.registry.inspect_model_cls(
                                self.draft_model_config.architectures,
                                self.draft_model_config,
                            )
                        )
                        self.draft_model_config._model_info = model_info
                        self.draft_model_config._architecture = arch

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    # Generate chain of tokens. If num_speculative_tokens is
                    # still unknown, leave the tree unset instead of crashing
                    # in range(None); the missing value is presumably reported
                    # with a descriptive ValueError by `_verify_args`.
                    if self.num_speculative_tokens is not None:
                        self.speculative_token_tree = str(
                            [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                        )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self
525

526
527
528
529
    def _validate_suffix_decoding(self):
        """Check the suffix-decoding options and fill in their defaults.

        If `num_speculative_tokens` is unset it is defaulted to
        `suffix_decoding_max_tree_depth` (with a warning).

        Raises:
            ImportError: if Arctic Inference is not installed.
            ValueError: if any suffix_decoding_* option is out of range.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )

        if self.num_speculative_tokens is None:
            # The proposer chooses the real number of speculative tokens per
            # step, so this value only acts as an upper bound.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

        max_depth = self.suffix_decoding_max_tree_depth
        max_cached = self.suffix_decoding_max_cached_requests
        spec_factor = self.suffix_decoding_max_spec_factor
        min_prob = self.suffix_decoding_min_token_prob

        if max_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth={max_depth} must be >= 1"
            )
        if max_cached < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests={max_cached} must be >= 0"
            )
        if spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor={spec_factor} must be >= 0"
            )
        # Chained comparison kept deliberately: it also rejects NaN.
        if not 0 <= min_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob={min_prob} must be in [0, 1]"
            )

562
563
    @staticmethod
    def _maybe_override_draft_max_model_len(
564
        speculative_max_model_len: int | None,
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
582
583
584
585
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
586
587

            if speculative_max_model_len > target_max_model_len:
588
589
590
591
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
592
593
594
595
596
597
598
599
600
601

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
602
        target_parallel_config: ParallelConfig,
603
        speculative_draft_tensor_parallel_size: int | None,
604
605
        draft_hf_config: PretrainedConfig,
    ) -> int:
606
607
608
609
610
611
612
613
614
615
616
617
618
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.
        """
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
619
620
                        draft_hf_config.model_type,
                    )
621
            else:
622
                speculative_draft_tensor_parallel_size = (
623
                    target_parallel_config.tensor_parallel_size
624
                )
625
        elif speculative_draft_tensor_parallel_size not in (
626
627
628
            1,
            target_parallel_config.tensor_parallel_size,
        ):
629
630
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
631
632
                f"other value than 1 or target model tensor_parallel_size"
            )
633
634
635
636
637
638
639
640
641
642
643
644
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Build the parallel config used by the draft worker.

        Copies the pipeline-parallel size, executor backend, loading-worker
        limit, custom all-reduce flag, Ray nsight flag and placement group
        from the target config. The tensor-parallel size is taken from
        `speculative_draft_tensor_parallel_size`. Any other ParallelConfig
        field is left at its constructor default.
        """
        source = target_parallel_config
        return ParallelConfig(
            pipeline_parallel_size=source.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=source.distributed_executor_backend,
            max_parallel_loading_workers=source.max_parallel_loading_workers,
            disable_custom_all_reduce=source.disable_custom_all_reduce,
            ray_workers_use_nsight=source.ray_workers_use_nsight,
            placement_group=source.placement_group,
        )

656
    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate cross-field invariants once the model has been built.

        Runs as a pydantic "after" model validator, so every field is already
        populated when this is called.

        Returns:
            ``self``, unchanged, when all checks pass.

        Raises:
            ValueError: if ``tensor_parallel_size`` was supplied (only
                ``draft_tensor_parallel_size`` is accepted here); if
                ``num_speculative_tokens`` is missing or not positive; if
                ``disable_by_batch_size`` is set below 2; if ``eagle3`` is
                requested for a target whose model_type is not supported; or
                (via ``verify_equal_vocab_size_if_draft_model``) if a
                ``draft_model`` has a different vocab size than the target.
        """
        # Reject the ambiguous name and point the user at the draft-specific one.
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )

        if self.num_speculative_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )

        # Let the draft model config check itself against the draft worker's
        # parallel layout.
        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        # Kept as a list: its repr is embedded verbatim in the error below.
        eagle3_target_supported = [
            "llama",
            "qwen",
            "minicpm",
            "gpt_oss",
            "hunyuan_vl",
            "hunyuan_v1_dense",
            "afmoe",
        ]
        # Substring match: any target model_type containing one of the entries
        # above (e.g. anything containing "qwen") is accepted.
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )
        self.verify_equal_vocab_size_if_draft_model()
        return self

    def verify_equal_vocab_size_if_draft_model(self):
        """Ensure a ``draft_model`` shares the target model's vocabulary size.

        The check only applies when ``method == "draft_model"`` and both the
        target and draft model configs are set; otherwise it is a no-op.

        Raises:
            ValueError: if the two configs report different vocab sizes.
        """
        if (
            self.method != "draft_model"
            or self.target_model_config is None
            or self.draft_model_config is None
        ):
            return

        target_vocab_size = self.target_model_config.get_vocab_size()
        draft_vocab_size = self.draft_model_config.get_vocab_size()
        if target_vocab_size != draft_vocab_size:
            raise ValueError(
                "Target and draft model should have the same vocabulary size. "
                f"Target model vocab_size={target_vocab_size}. "
                f"Draft model vocab_size={draft_vocab_size}. "
                "Using models with different tokenizers can cause out-of-bounds "
                "errors during speculative decoding."
            )

    def use_eagle(self) -> bool:
        """Return True when ``method`` is exactly "eagle", "eagle3" or "mtp"."""
        return self.method in ("eagle", "eagle3", "mtp")

    def uses_draft_model(self) -> bool:
        """Return True when the speculative method is "draft_model"."""
        return self.method == "draft_model"

    def __repr__(self) -> str:
        """Summarize the method, draft model name and speculative token count.

        ``model`` is reported as None for the model-free methods ("ngram",
        "suffix") and whenever no draft model config is set, so ``repr()``
        never raises on a config whose draft model is missing.
        """
        # NOTE: these local names are printed verbatim by the `{x=}` f-string.
        method = self.method
        if method in ("ngram", "suffix") or self.draft_model_config is None:
            model = None
        else:
            model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"