speculative.py 26 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import hashlib
6
from typing import TYPE_CHECKING, Any, Literal
7

8
from pydantic import Field, SkipValidation, model_validator
9
10
11
12
13
14
15
from pydantic.dataclasses import dataclass
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
16
from vllm.utils.import_utils import LazyLoader
17
18
19
20
21
22
23
24
25
26

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
    from vllm.config import ModelConfig
else:
    # At runtime these names are only used in annotations / as opaque types,
    # so plain `Any` placeholders avoid importing transformers and vllm.config.
    PretrainedConfig = Any
    ModelConfig = Any

    # The quantization package is bound through LazyLoader (presumably to
    # defer importing it until an attribute such as `QuantizationMethods`
    # is first accessed — confirm against vllm.utils.import_utils).
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)
# Every accepted value of `SpeculativeConfig.method`. The per-model MTP names
# (everything in MTP_MODEL_TYPES) are deprecated aliases that
# SpeculativeConfig.__post_init__ rewrites to "mtp"; they must stay listed
# here, otherwise validation rejects them before that rewrite can run.
SpeculativeMethod = Literal[
    "ngram",
    "eagle",
    "eagle3",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "deepseek_mtp",
    "ernie_mtp",
    "qwen3_next_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "longcat_flash_mtp",
    "mtp",
]

# HF `model_type` values of MTP draft models (as produced by
# SpeculativeConfig.hf_config_override). They double as the deprecated
# method aliases that are normalized to "mtp".
MTP_MODEL_TYPES = (
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "ernie_mtp",
    "qwen3_next_mtp",
    "longcat_flash_mtp",
)


@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""
    # General speculative decoding control
    num_speculative_tokens: int | None = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present, otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible, if `model` param is not provided, the method name must be
    provided.

    If using `ngram` method, the related configuration `prompt_lookup_max` and
    `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of the tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""
    disable_logprobs: bool = True
    """If set to True, token log probabilities are not returned during
    speculative decoding. If set to False, token log probabilities are returned
    according to the log probability settings in SamplingParams."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of ngram token window when using Ngram proposer, required
    when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of ngram token window when using Ngram proposer, if
    provided. Defaults to 1."""

    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation.
    """
    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""
    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
    """Whether vLLM is configured to use chunked prefill or not. Used for
    raising an error since it's not yet compatible with speculative decode."""
    disable_log_stats: SkipValidation[bool] = None  # type: ignore
    """Whether to disable the periodic printing of stage times in speculative
    decoding."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model initialized internal."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model initialized internal."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @staticmethod
    def _first_architecture(hf_config: PretrainedConfig) -> str | None:
        """Return the first entry of `hf_config.architectures`, or None.

        Configs without (or with an empty) `architectures` list simply match
        no architecture-based override instead of raising TypeError/IndexError.
        """
        architectures = getattr(hf_config, "architectures", None)
        return architectures[0] if architectures else None

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        """Rewrite a target-model HF config into its MTP draft counterpart.

        Passed as `hf_overrides` when the draft `ModelConfig` is built. For
        known MTP-capable families the config is mutated in place: its
        `model_type` becomes the matching `*_mtp` type, `n_predict` is taken
        from `num_nextn_predict_layers`, and `architectures` is pointed at the
        MTP model class. Any other config is returned unchanged.

        Args:
            hf_config: The HF config loaded for the draft model.

        Returns:
            The same (possibly mutated) config object.
        """
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        if SpeculativeConfig._first_architecture(hf_config) == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if SpeculativeConfig._first_architecture(hf_config) == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            # LongCat configs may omit the field; default to a single layer.
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        return hf_config

    def __post_init__(self):
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weight
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        # Per-model MTP method names are deprecated aliases of "mtp".
        if self.method in MTP_MODEL_TYPES:
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided but without speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method"
        if self.method is None and self.model in ("ngram", "[ngram]"):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided; when only one bound is
            # given, the other one mirrors it.
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

            if self.model is not None:
                self._init_draft_model()
        return self

    def _init_draft_model(self) -> None:
        """Build the draft model/parallel configs for model-based methods.

        Creates `draft_model_config`, detects `method`, wraps EAGLE heads in
        an `EAGLEConfig`, resolves `num_speculative_tokens` from the draft
        config's `n_predict`, normalizes `speculative_token_tree`, and
        derives the draft TP size, max model length and parallel config.

        Raises:
            ValueError: on inconsistent speculative settings.
            NotImplementedError: if the draft is a plain draft model.
        """
        # TODO: Move this import to the top once `ModelConfig`
        # lives in `vllm.config.model`.
        from vllm.config import ModelConfig

        target = self.target_model_config
        self.draft_model_config = ModelConfig(
            model=self.model,
            runner="draft",
            tokenizer=target.tokenizer,
            tokenizer_mode=target.tokenizer_mode,
            trust_remote_code=target.trust_remote_code,
            allowed_local_media_path=target.allowed_local_media_path,
            allowed_media_domains=target.allowed_media_domains,
            dtype=target.dtype,
            seed=target.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            tokenizer_revision=target.tokenizer_revision,
            spec_target_max_model_len=target.max_model_len,
            quantization=self.quantization,
            enforce_eager=target.enforce_eager,
            max_logprobs=target.max_logprobs,
            hf_overrides=SpeculativeConfig.hf_config_override,
        )

        self._detect_method()

        # Replace hf_config for EAGLE draft_model
        if self.method in ("eagle", "eagle3"):
            if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                raise ValueError(
                    "Chunked prefill and EAGLE are not compatible "
                    "when using V0."
                )

            from vllm.transformers_utils.configs import SpeculatorsConfig
            from vllm.transformers_utils.configs.eagle import EAGLEConfig

            draft_hf_config = self.draft_model_config.hf_config
            if not isinstance(draft_hf_config, (EAGLEConfig, SpeculatorsConfig)):
                self.draft_model_config.hf_config = EAGLEConfig(
                    draft_hf_config,
                    method=self.method,
                    model_type="eagle",
                )

        if self.num_speculative_tokens is not None and hasattr(
            self.draft_model_config.hf_config, "num_lookahead_tokens"
        ):
            self.draft_model_config.hf_config.num_lookahead_tokens = (
                self.num_speculative_tokens
            )

        n_predict = getattr(self.draft_model_config.hf_config, "n_predict", None)
        if n_predict is not None:
            if self.num_speculative_tokens is None:
                # Default to max value defined in draft model config.
                self.num_speculative_tokens = n_predict
            elif (
                self.num_speculative_tokens > n_predict
                and self.num_speculative_tokens % n_predict != 0
            ):
                # Ensure divisibility for MTP module reuse.
                raise ValueError(
                    f"num_speculative_tokens:{self.num_speculative_tokens}"
                    f" must be divisible by {n_predict=}"
                )

        if self.speculative_token_tree is None:
            if self.num_speculative_tokens is None:
                # Fail with the same error _verify_args would report, instead
                # of a TypeError from range(None) below.
                raise ValueError(
                    "num_speculative_tokens must be provided with "
                    "speculative model unless the draft model config contains an "
                    "n_predict parameter."
                )
            # Generate chain of tokens.
            self.speculative_token_tree = str(
                [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
            )
        else:
            # Sort the token tree breadth-first.
            tree_choices = ast.literal_eval(self.speculative_token_tree)
            self.speculative_token_tree = str(
                sorted(tree_choices, key=lambda t: (len(t), t))
            )

        self.draft_tensor_parallel_size = SpeculativeConfig._verify_and_get_draft_tp(
            self.target_parallel_config,
            self.draft_tensor_parallel_size,
            self.draft_model_config.hf_config,
        )

        self.draft_model_config.max_model_len = (
            SpeculativeConfig._maybe_override_draft_max_model_len(
                self.max_model_len,
                self.draft_model_config.max_model_len,
                target.max_model_len,
            )
        )

        self.draft_parallel_config = SpeculativeConfig.create_draft_parallel_config(
            self.target_parallel_config, self.draft_tensor_parallel_size
        )

    def _detect_method(self) -> None:
        """Infer `self.method` from the draft model name and HF config.

        An explicit "eagle"/"eagle3" method is kept. Otherwise the model name
        and `hf_config.model_type` decide; anything unrecognized is a plain
        draft model, which is not supported and raises NotImplementedError.
        """
        if self.method in ("eagle", "eagle3"):
            return

        model_name = self.draft_model_config.model.lower()
        model_type = self.draft_model_config.hf_config.model_type
        # examples:
        # yuhuili/EAGLE-LLaMA3-Instruct-8B
        # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
        # AngelSlim/Qwen3-8B_eagle3
        if "eagle-" in model_name:
            self.method = "eagle"
        elif "eagle3" in model_name:
            self.method = "eagle3"
        elif model_type == "medusa":
            self.method = "medusa"
        elif model_type == "mlp_speculator":
            self.method = "mlp_speculator"
        elif model_type in MTP_MODEL_TYPES:
            # Covers every MTP family, including longcat_flash_mtp.
            self.method = "mtp"
            # num_speculative_tokens may still be None here (it can be
            # defaulted from n_predict later), so guard the comparison.
            if (
                self.num_speculative_tokens is not None
                and self.num_speculative_tokens > 1
            ):
                logger.warning(
                    "Enabling num_speculative_tokens > 1 will run "
                    "multiple times of forward on same MTP layer, "
                    "which may result in lower acceptance rate"
                )
        else:
            self.method = "draft_model"
            raise NotImplementedError(
                "Speculative decoding with draft model is not "
                "supported yet. Please consider using other "
                "speculative decoding methods such as ngram, medusa, "
                "eagle, or mtp."
            )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: if speculative_max_model_len exceeds either limit.
        """

        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.

        When unset, mlp_speculator drafts use TP=1 and every other draft
        inherits the target TP. An explicit value must be 1 or the target TP.

        Raises:
            ValueError: if an explicit draft TP is neither 1 nor the target TP.
        """
        target_tp = target_parallel_config.tensor_parallel_size
        # If speculative_draft_tensor_parallel_size is unset then set it
        # appropriately else verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_tp > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
                        draft_hf_config.model_type,
                    )
            else:
                speculative_draft_tensor_parallel_size = target_tp
        elif speculative_draft_tensor_parallel_size not in (1, target_tp):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1 or target model tensor_parallel_size "
                f"({target_tp})"
            )
        return speculative_draft_tensor_parallel_size

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
            ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Cross-field validation run after the dataclass is initialized."""
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )

        if self.num_speculative_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expect the batch size threshold of disabling "
                "speculative decoding is > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        # Substring match against the target's text model_type.
        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )

        return self

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def use_eagle(self) -> bool:
        """Whether the method uses the EAGLE-style proposer (eagle/eagle3/mtp)."""
        return self.method in ("eagle", "eagle3", "mtp")

    def __repr__(self) -> str:
        method = self.method
        if method == "ngram" or self.draft_model_config is None:
            model = None
        else:
            model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"