speculative.py 34.3 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
5
from typing import TYPE_CHECKING, Any, Literal, get_args
6

7
from pydantic import Field, SkipValidation, model_validator
8
9
10
from pydantic.dataclasses import dataclass
from typing_extensions import Self

luopl's avatar
luopl committed
11
from vllm.config import LoadConfig
12
from vllm.config.model import ModelConfig
13
14
15
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
16
from vllm.transformers_utils.config import get_hf_text_config
17
from vllm.utils.hashing import safe_hash
18
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
19
20
21
22
23
24
25
26

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    # Runtime stand-in so annotations such as `hf_config: PretrainedConfig`
    # evaluate without importing transformers.
    PretrainedConfig = Any

    # Proxy for the quantization package; `me_quant.QuantizationMethods` is
    # referenced by the `quantization` field annotation of SpeculativeConfig.
    # NOTE(review): LazyLoader presumably defers the real import until first
    # attribute access — confirm in vllm.utils.import_utils.
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

# Model types that carry multi-token-prediction (MTP) layers. Passing any of
# these other than the generic "mtp" as `SpeculativeConfig.method` is
# deprecated; `SpeculativeConfig.__post_init__` rewrites them to "mtp".
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "ernie_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]
# The EAGLE methods plus every MTP model type (nested Literals are flattened).
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
# Every value accepted by `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]


@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""
62

63
    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""

    # General speculative decoding control
    # `None` means "not provided": `__post_init__` fills it in from the draft
    # model config's `n_predict` or, for suffix decoding, the max tree depth.
    num_speculative_tokens: int | None = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    enable_multi_layers_mtp: bool = False
    """If set to True, the MTP method will run multiple layers of MTP
    speculator. If set to False, it will run only one layer of MTP speculator.
    This is only effective when the method is set to `mtp`."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    tensor_parallel_size: int | None = None
    """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
    warn users when they mistakenly provide the wrong argument."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    use_local_argmax_reduction: bool = False
    """Use vocab-parallel local argmax instead of all-gathering full logits
    for draft token generation. Reduces communication from O(vocab_size) to
    O(2 * tp_size) per token. Only applies to greedy draft selection in
    non-tree speculation."""
    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation, given as
    the string form of a list of tuples (parsed with `ast.literal_eval`). If
    unset, a single chain of `num_speculative_tokens` tokens is used."""
    parallel_drafting: bool = False
    """Enable parallel drafting, where all speculative tokens are generated
    in parallel rather than sequentially. This can improve performance but
    requires the speculative model be trained to support parallel drafting.
    Only compatible with EAGLE and draft model methods."""

    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""

    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""

    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""

    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""

    draft_load_config: LoadConfig | None = None
    """Load config for the draft model. If not specified, will use the load
    config from the target model."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            Hex digest of the graph-affecting factors.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a Hugging Face config so it loads as an MTP/EAGLE drafter.

        Passed as `hf_overrides` when the draft `ModelConfig` is built. For
        each supported model family this switches `model_type` to the
        matching `*_mtp` type, points `architectures` at the MTP model class,
        and records the family's MTP layer count (e.g.
        `num_nextn_predict_layers`) as `n_predict`. Configs that match none of
        the rules are returned unchanged.

        Args:
            hf_config: The draft model's Hugging Face config. It is modified
                in place.

        Returns:
            The same config object, after modification.
        """
        # Captured before any rule below can overwrite `architectures`.
        initial_architecture = hf_config.architectures[0]

        if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        # Use `==`, not `in ("pangu_ultra_moe")`: a parenthesized string is not
        # a tuple, so `in` would do a substring test and also match unrelated
        # values such as an empty `model_type`.
        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
            hf_config.model_type = "glm4_moe_lite_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeLiteMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "exaone_moe":
            hf_config.model_type = "exaone_moe_mtp"
        if hf_config.model_type == "exaone_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
            )

        if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
            is_moe = hf_config.model_type == "qwen3_5_moe"
            hf_config.model_type = "qwen3_5_mtp"
            # Qwen3.5 stores its MTP layer count under a different key.
            n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
            hf_config.update(
                {
                    "n_predict": n_predict,
                    "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
                }
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if hf_config.model_type == "step3p5":
            hf_config.model_type = "step3p5_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

        # Decided on the original architecture, not one rewritten above.
        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

    def __post_init__(self):
        """Resolve defaults and derived settings once the fields are set.

        Normalizes `method`, fills in `model` for proposers that need no
        separate checkpoint (MTP reusing the target model, ngram, suffix),
        applies proposer-specific defaults and checks, and, for model-based
        proposers, builds `draft_model_config` and `draft_parallel_config`
        from the target configs.

        Raises:
            ValueError: On inconsistent or missing settings (e.g.
                `num_speculative_tokens` without any usable model,
                `prompt_lookup_min > prompt_lookup_max`, or a token count that
                is incompatible with the draft model's `n_predict`).
            NotImplementedError: If a draft model is given, its speculative
                method cannot be detected, and `method` is not "draft_model".
        """
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # is detected automatically from the draft model when possible. If it
        # cannot be detected, "method" must be set explicitly (for example to
        # "draft_model"); otherwise a NotImplementedError is raised below.

        # Per-family MTP method names are deprecated aliases of "mtp".
        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        # Fill in `model` for proposers that need no separate draft checkpoint.
        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method" (None is never in the tuple).
        if self.method is None and self.model in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Default both window bounds, or derive the missing one from the
            # one that was provided.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self.draft_model_config = ModelConfig(
                    model=self.model,
                    runner="draft",
                    tokenizer=self.target_model_config.tokenizer,
                    tokenizer_mode=self.target_model_config.tokenizer_mode,
                    trust_remote_code=self.target_model_config.trust_remote_code,
                    allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.tokenizer_revision,
                    spec_target_max_model_len=self.target_model_config.max_model_len,
                    quantization=self.quantization,
                    enforce_eager=self.target_model_config.enforce_eager,
                    max_logprobs=self.target_model_config.max_logprobs,
                    hf_overrides=SpeculativeConfig.hf_config_override,
                    config_format=self.target_model_config.config_format,
                )

                # Automatically detect the method
                if self.method in ("eagle", "eagle3"):
                    pass
                # examples:
                # yuhuili/EAGLE-LLaMA3-Instruct-8B
                # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
                # AngelSlim/Qwen3-8B_eagle3
                elif "eagle-" in self.draft_model_config.model.lower():
                    self.method = "eagle"
                elif "eagle3" in self.draft_model_config.model.lower():
                    self.method = "eagle3"
                elif self.draft_model_config.hf_config.model_type == "medusa":
                    self.method = "medusa"
                elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                    self.method = "mlp_speculator"
                elif self.draft_model_config.hf_config.model_type in get_args(
                    MTPModelTypes
                ):
                    # Also covers "longcat_flash_mtp", which is listed in
                    # MTPModelTypes.
                    self.method = "mtp"
                    # num_speculative_tokens may still be None here; it is
                    # defaulted from n_predict further below.
                    if (
                        not self.enable_multi_layers_mtp
                        and self.num_speculative_tokens is not None
                        and self.num_speculative_tokens > 1
                    ):
                        logger.warning(
                            "Enabling num_speculative_tokens > 1 will run "
                            "multiple times of forward on same MTP layer, "
                            "which may result in lower acceptance rate"
                        )
                elif self.method == "draft_model":
                    pass
                else:
                    raise NotImplementedError(
                        f"Unsupported speculative method: '{self.method}'"
                    )

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
                    from vllm.transformers_utils.configs import SpeculatorsConfig
                    from vllm.transformers_utils.configs.eagle import EAGLEConfig

                    if not isinstance(
                        self.draft_model_config.hf_config,
                        (EAGLEConfig, SpeculatorsConfig),
                    ):
                        eagle_config = EAGLEConfig(
                            self.draft_model_config.hf_config,
                            method=self.method,
                            model_type="eagle",
                        )
                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config.hf_text_config = get_hf_text_config(
                            self.draft_model_config.hf_config
                        )
                        self.draft_model_config.model_arch_config = (
                            self.draft_model_config.get_model_arch_config()
                        )
                        model_info, arch = (
                            self.draft_model_config.registry.inspect_model_cls(
                                self.draft_model_config.architectures,
                                self.draft_model_config,
                            )
                        )
                        self.draft_model_config._model_info = model_info
                        self.draft_model_config._architecture = arch

                if self.num_speculative_tokens is not None and hasattr(
                    self.draft_model_config.hf_config, "num_lookahead_tokens"
                ):
                    self.draft_model_config.hf_config.num_lookahead_tokens = (
                        self.num_speculative_tokens
                    )

                n_predict = getattr(
                    self.draft_model_config.hf_config, "n_predict", None
                )
                if n_predict is not None:
                    if self.num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.method == "mtp"
                        and self.enable_multi_layers_mtp
                        and self.num_speculative_tokens > n_predict
                    ):
                        # Multi-layer MTP cannot draft more tokens than the
                        # draft model has layers; clamp to the layer count.
                        logger.warning_once(
                            "For multi_layer_eagle, num_speculative_tokens "
                            "is greater than the layer_num, adjusting to "
                            "layer_num"
                        )
                        self.num_speculative_tokens = n_predict
                    elif (
                        self.num_speculative_tokens > n_predict
                        and self.num_speculative_tokens % n_predict != 0
                    ):
                        # Ensure divisibility for MTP module reuse.
                        raise ValueError(
                            f"num_speculative_tokens:{self.num_speculative_tokens}"
                            f" must be divisible by {n_predict=}"
                        )

                if self.speculative_token_tree is None:
                    if self.num_speculative_tokens is None:
                        # Neither the user nor the draft config (n_predict)
                        # supplied a token count to build the chain from.
                        raise ValueError(
                            "num_speculative_tokens must be provided when the "
                            "draft model config does not define n_predict and "
                            "no speculative_token_tree is given."
                        )
                    # Generate chain of tokens: [(0,), (0, 0), (0, 0, 0), ...].
                    self.speculative_token_tree = str(
                        [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                    )
                else:
                    # Sort the token tree breadth-first.
                    tree_choices = ast.literal_eval(self.speculative_token_tree)
                    self.speculative_token_tree = str(
                        sorted(tree_choices, key=lambda t: (len(t), t))
                    )

                self.draft_tensor_parallel_size = (
                    SpeculativeConfig._verify_and_get_draft_tp(
                        self.target_parallel_config,
                        self.draft_tensor_parallel_size,
                        self.draft_model_config.hf_config,
                    )
                )

                self.draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        self.max_model_len,
                        self.draft_model_config.max_model_len,
                        self.target_model_config.max_model_len,
                    )
                )

                self.draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        self.target_parallel_config, self.draft_tensor_parallel_size
                    )
                )
        return self

    def _validate_suffix_decoding(self):
        """Check suffix-decoding prerequisites and settings.

        Requires the Arctic Inference package and validates the
        `suffix_decoding_*` fields. If `num_speculative_tokens` was not given,
        it defaults to `suffix_decoding_max_tree_depth`.

        Raises:
            ImportError: If Arctic Inference is not installed.
            ValueError: If any `suffix_decoding_*` value is out of range.
        """
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )
        # Validate values before deriving anything from them, so an invalid
        # tree depth is reported without first logging a default based on it.
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )
        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
593
        speculative_max_model_len: int | None,
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
611
612
613
614
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
615
616

            if speculative_max_model_len > target_max_model_len:
617
618
619
620
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
621
622
623
624
625
626
627
628
629
630

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.

        When the size is unset, mlp_speculator drafters get 1 and every other
        drafter inherits the target model's tensor_parallel_size.

        Returns:
            The draft tensor parallel size to use.

        Raises:
            ValueError: If an explicit size is neither 1 nor the target model's
                tensor_parallel_size.
        """
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
                        draft_hf_config.model_type,
                    )
            else:
                speculative_draft_tensor_parallel_size = (
                    target_parallel_config.tensor_parallel_size
                )
        elif speculative_draft_tensor_parallel_size not in (
            1,
            target_parallel_config.tensor_parallel_size,
        ):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"other value than 1 or target model tensor_parallel_size"
            )
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        The pipeline-parallel size, executor backend, loading workers,
        custom-all-reduce flag, nsight flag and placement group are copied
        from the target parallel config. Only the tensor parallel size is
        replaced by speculative_draft_tensor_parallel_size.

        Args:
            target_parallel_config: Parallel config of the target model.
            speculative_draft_tensor_parallel_size: TP size for the draft.

        Returns:
            A new ParallelConfig for the draft worker.
        """
        return ParallelConfig(
            pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
            ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

    @model_validator(mode="after")
686
    def _verify_args(self) -> Self:
687
688
689
690
691
692
        if self.tensor_parallel_size is not None:
            raise ValueError(
                "'tensor_parallel_size' is not a valid argument in the "
                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
            )

693
694
695
696
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
697
698
                "n_predict parameter."
            )
699
700

        if self.num_speculative_tokens <= 0:
701
702
703
704
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )
705
706
707

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
708
709
                self.draft_parallel_config
            )
710

711
712
713
714
715
716
        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )
717

718
        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
719
720
721
722
723
724
725
726
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
727
728
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
729
730
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )
731
        self.verify_equal_vocab_size_if_draft_model()
732
733
        return self

734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
    def verify_equal_vocab_size_if_draft_model(self):
        if (
            self.method == "draft_model"
            and self.target_model_config is not None
            and self.draft_model_config is not None
        ):
            target_vocab_size = self.target_model_config.get_vocab_size()
            draft_vocab_size = self.draft_model_config.get_vocab_size()
            if target_vocab_size != draft_vocab_size:
                raise ValueError(
                    f"Target and draft model should have the same vocabulary size. "
                    f"Target model vocab_size={target_vocab_size}. "
                    f"Draft model vocab_size={draft_vocab_size}. "
                    f"Using models with different tokenizers can cause out-of-bounds "
                    f"errors during speculative decoding."
                )

luopl's avatar
luopl committed
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
    @property
    def max_num_new_slots_for_drafting(self) -> int:
        """
        Calculate the maximum number of new slots that might be added to the batch
        when drafting.
        """
        slots_per_req = 0  # for serial non-draft-model methods, no change needed
        if self.parallel_drafting:
            # For parallel drafting, we need one new slot per 'masked' token
            slots_per_req = self.num_speculative_tokens - 1
        if self.uses_draft_model():
            # For draft model-based speculation, we need one new slot per request
            # Since we do not slice the draft tokens
            slots_per_req += 1
        return slots_per_req

767
    def use_eagle(self) -> bool:
768
        return self.method in ("eagle", "eagle3", "mtp")
769

770
771
772
    def uses_draft_model(self) -> bool:
        return self.method == "draft_model"

luopl's avatar
luopl committed
773
774
775
    def uses_extract_hidden_states(self) -> bool:
        return self.method == "extract_hidden_states"

776
777
    def __repr__(self) -> str:
        method = self.method
778
        model = None if method in ("ngram", "suffix") else self.draft_model_config.model
779
780
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"