# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
from typing import TYPE_CHECKING, Any, Literal, get_args
6

7
from pydantic import Field, SkipValidation, model_validator
8
9
10
from pydantic.dataclasses import dataclass
from typing_extensions import Self

11
from vllm.config.model import ModelConfig
12
13
14
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
15
from vllm.transformers_utils.config import get_hf_text_config
16
from vllm.utils.hashing import safe_hash
17
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
18
19
20
21
22
23
24
25

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    # `PretrainedConfig` is only used in annotations here, so at runtime it is
    # aliased to `Any` and transformers is not imported by this module.
    PretrainedConfig = Any

    # The quantization package is wrapped in a LazyLoader instead of being
    # imported directly (presumably to defer the import / avoid an import
    # cycle -- confirm against LazyLoader's implementation).
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Draft `hf_config.model_type` values that identify a multi-token-prediction
# (MTP) draft. They are also accepted as `SpeculativeConfig.method` values;
# every one except "mtp" is deprecated and rewritten to "mtp" (with a warning)
# in `SpeculativeConfig.__post_init__`.
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "ernie_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
]
# EAGLE-family methods. Nested Literals are flattened, so `get_args` on this
# alias yields "eagle", "eagle3" followed by every MTP name above.
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
# Every value accepted by `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]


@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding.

    User-facing fields (``method``, ``model``, ``num_speculative_tokens``,
    ...) are set by the caller; ``target_model_config`` and
    ``target_parallel_config`` are passed in by the engine, and
    ``__post_init__`` derives ``draft_model_config`` and
    ``draft_parallel_config`` from them.
    """

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""
    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation.
    """
    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a target-model HF config so it loads as its draft variant.

        Used as ``hf_overrides`` when the draft ``ModelConfig`` is built. Base
        model types that carry MTP layers are renamed to their ``*_mtp`` model
        type, ``n_predict`` is taken from ``num_nextn_predict_layers``, and
        ``architectures`` is pointed at the matching draft model class.

        Args:
            hf_config: The HF config to rewrite. It is mutated in place.

        Returns:
            The same (mutated) ``hf_config`` object.
        """
        # Captured before any rewrite below replaces `architectures`.
        initial_architecture = hf_config.architectures[0]
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Exact comparison: `x in ("pangu_ultra_moe")` would be a substring
        # test on a plain string, not tuple membership.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            # LongCat defaults to a single MTP layer when the field is absent.
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
        """Resolve the speculative method and derive the draft-side configs.

        Normalizes deprecated MTP method names to "mtp", fills in ``model``
        when only ``method`` is given, applies ngram/suffix defaults, and for
        model-based methods builds ``draft_model_config``, auto-detects the
        method, resolves ``num_speculative_tokens``, the token tree, the draft
        tensor-parallel size, max model length, and ``draft_parallel_config``.

        Raises:
            ValueError: on inconsistent or missing settings.
            NotImplementedError: if the draft model is not a recognized
                speculator type (plain draft-model speculation).
        """
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method"
        if self.method is None and (
            self.model is not None and self.model in ("ngram", "[ngram]")
        ):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided. When only one bound is
            # given, the missing bound mirrors it.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    # Also covers "longcat_flash_mtp", which is a member of
                    # MTPModelTypes.
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further below.
                    if (
                        self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer, "
                            "which may result in lower acceptance rate"
                        )
                else:
                    self.method = "draft_model"
                    raise NotImplementedError(
                        "Speculative decoding with draft model is not "
                        "supported yet. Please consider using other "
                        "speculative decoding methods such as ngram, medusa, "
                        "eagle, or mtp."
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    if isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        pass
                    else:
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config.hf_text_config = get_hf_text_config(
                            self.draft_model_config.hf_config
                        )
                        self.draft_model_config.model_arch_config = (
                            self.draft_model_config.get_model_arch_config()
                        )
                        model_info, arch = (
                            self.draft_model_config.registry.inspect_model_cls(
                                self.draft_model_config.architectures,
                                self.draft_model_config,
                            )
                        )
                        self.draft_model_config._model_info = model_info
                        self.draft_model_config._architecture = arch

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                # Fail with a clear message instead of a TypeError from
                # range(None) when no token count could be determined.
                if self.num_speculative_tokens is None:
                    raise ValueError(
                        "num_speculative_tokens must be provided with "
                        "speculative model unless the draft model config "
                        "contains an n_predict parameter."
                    )

                if self.speculative_token_tree is None:
                    # Generate chain of tokens.
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self

    def _validate_suffix_decoding(self):
        """Validate and default the suffix-decoding settings.

        Raises:
            ImportError: if Arctic Inference is not installed.
            ValueError: if a suffix_decoding_* parameter is out of range.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )
        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )
        # Validate values
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.
        """
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
                        draft_hf_config.model_type,
                    )
            else:
                speculative_draft_tensor_parallel_size = (
                    target_parallel_config.tensor_parallel_size
                )
        elif speculative_draft_tensor_parallel_size not in (
            1,
            target_parallel_config.tensor_parallel_size,
        ):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"other value than 1 or target model tensor_parallel_size"
            )
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
            ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Cross-field validation run after the dataclass is constructed."""
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )

        if self.num_speculative_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        # Substring match against the target's model_type
        # (e.g. "qwen" also matches "qwen2"/"qwen3").
        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )

        return self

    def use_eagle(self) -> bool:
        """Return True for methods served by the EAGLE-style proposer."""
        return self.method in ("eagle", "eagle3", "mtp")

    def __repr__(self) -> str:
        method = self.method
        # draft_model_config can still be None (e.g. a model-based method
        # configured without a model); don't crash while formatting.
        if method in ("ngram", "suffix") or self.draft_model_config is None:
            model = None
        else:
            model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"