Unverified commit cd764301, authored by Andrii Skliar, committed by GitHub
Browse files

[Feature] Support per-draft-model MoE backend via `--speculative-config` (#37880)


Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: [Andrii Skliar] <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
parent a1a25664
...@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams ...@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config import VllmConfig from vllm.config import VllmConfig, replace
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
MTP_SIMILARITY_RATE = 0.8 MTP_SIMILARITY_RATE = 0.8
...@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism(): ...@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
"draft_tensor_parallel_size": 1, # <<< valid arg name "draft_tensor_parallel_size": 1, # <<< valid arg name
}, },
) )
tgt_vllm_config: VllmConfig = engine_args.create_engine_config() target_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 assert target_config.parallel_config.tensor_parallel_size == 2
assert tgt_vllm_config.quant_config.get_name() == "fp8" assert target_config.quant_config.get_name() == "fp8"
speculative_config = target_config.speculative_config
draft_config: VllmConfig = replace(
target_config,
quant_config=None,
parallel_config=replace(
speculative_config.draft_parallel_config,
rank=target_config.parallel_config.rank,
),
model_config=speculative_config.draft_model_config,
)
assert draft_config.parallel_config.tensor_parallel_size == 1
assert draft_config.quant_config is None
def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
    """Mirror the MoE-backend handling of
    SpecDecodeBaseProposer._create_draft_vllm_config, so it can be tested
    without instantiating a full proposer.

    Returns ``vllm_config`` itself when its speculative config has no
    ``moe_backend``; otherwise returns the result of ``replace`` with the
    kernel config's ``moe_backend`` set to the speculative one.
    """
    draft_backend = vllm_config.speculative_config.moe_backend
    if draft_backend is None:
        # Nothing to override: callers rely on getting the same object back.
        return vllm_config

    draft_kernel_config = replace(
        vllm_config.kernel_config,
        moe_backend=draft_backend,
    )
    return replace(vllm_config, kernel_config=draft_kernel_config)
def test_draft_model_moe_backend_override():
    """A moe_backend given in speculative_config must end up on the draft
    VllmConfig, while the target config keeps its own value."""
    spec_settings = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "moe_backend": "triton",
    }
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_trtllm",
        speculative_config=spec_settings,
    )
    target: VllmConfig = args.create_engine_config()

    # Both settings are recorded independently on the target config.
    assert target.kernel_config.moe_backend == "flashinfer_trtllm"
    assert target.speculative_config.moe_backend == "triton"

    draft = _apply_draft_moe_backend(target)
    assert draft.kernel_config.moe_backend == "triton"

    # Building the draft config must not touch the target config.
    assert target.kernel_config.moe_backend == "flashinfer_trtllm"
def test_draft_model_moe_backend_inherits_target():
    """Without a moe_backend in speculative_config, the draft config must
    carry the target's moe_backend."""
    spec_settings = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
    }
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_cutlass",
        speculative_config=spec_settings,
    )
    target: VllmConfig = args.create_engine_config()

    assert target.kernel_config.moe_backend == "flashinfer_cutlass"
    assert target.speculative_config.moe_backend is None

    draft = _apply_draft_moe_backend(target)
    assert draft.kernel_config.moe_backend == "flashinfer_cutlass"
    # No override requested, so the helper returns the very same object.
    assert draft is target
def test_draft_model_moe_backend_default_auto():
"""When neither target nor draft set moe_backend explicitly, both should
default to 'auto'."""
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
},
)
tgt_config: VllmConfig = engine_args.create_engine_config()
assert tgt_config.kernel_config.moe_backend == "auto"
assert tgt_config.speculative_config.moe_backend is None
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) draft_config = _apply_draft_moe_backend(tgt_config)
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 assert draft_config.kernel_config.moe_backend == "auto"
assert draft_vllm_config.quant_config is None assert draft_config is tgt_config
def test_draft_model_engine_args_rejects_invalid_tp_argname(): def test_draft_model_engine_args_rejects_invalid_tp_argname():
......
...@@ -9,6 +9,7 @@ from pydantic import Field, SkipValidation, model_validator ...@@ -9,6 +9,7 @@ from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self from typing_extensions import Self
from vllm.config import LoadConfig from vllm.config import LoadConfig
from vllm.config.kernel import MoEBackend
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
...@@ -93,6 +94,11 @@ class SpeculativeConfig: ...@@ -93,6 +94,11 @@ class SpeculativeConfig:
"""Quantization method that was used to quantize the draft model weights. """Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method.""" takes effect when using the draft model-based speculative method."""
moe_backend: MoEBackend | None = None
"""MoE backend to use for the draft model. When `None`, the draft model
inherits the target model's `--moe-backend` setting. Useful when the
drafter and generator require different MoE kernels (e.g. quantized
generator with unquantized drafter)."""
max_model_len: int | None = Field(default=None, ge=1) max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the """The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences.""" ability to skip speculation for some sequences."""
......
...@@ -6,10 +6,10 @@ import torch.nn as nn ...@@ -6,10 +6,10 @@ import torch.nn as nn
from typing_extensions import override from typing_extensions import override
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import replace
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,16 +50,29 @@ class DraftModelProposer(SpecDecodeBaseProposer): ...@@ -50,16 +50,29 @@ class DraftModelProposer(SpecDecodeBaseProposer):
"Please pass 'draft_tensor_parallel_size' in the speculative_config." "Please pass 'draft_tensor_parallel_size' in the speculative_config."
) )
@override
def _create_draft_vllm_config(self) -> VllmConfig:
    """Build the VllmConfig used for the draft model.

    Starts from the base-class result, then clears ``quant_config`` and
    substitutes the draft model's parallel and model configs taken from
    the speculative config. The draft parallel config receives the rank
    of this proposer's own parallel config.
    """
    base_config = super()._create_draft_vllm_config()
    spec_config = self.speculative_config
    draft_parallel_config = replace(
        spec_config.draft_parallel_config,
        rank=self.vllm_config.parallel_config.rank,
    )
    return replace(
        base_config,
        quant_config=None,
        parallel_config=draft_parallel_config,
        model_config=spec_config.draft_model_config,
    )
@override @override
def _get_model(self) -> nn.Module: def _get_model(self) -> nn.Module:
# Draft models may be quantized or on different parallelism,
# so we load them with a modified vllm config
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("draft_model"): with set_model_tag("draft_model"):
model = get_model( model = get_model(
vllm_config=temp_vllm_config, vllm_config=draft_vllm_config,
prefix="draft_model", prefix="draft_model",
) )
return model return model
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
from dataclasses import replace
from importlib.util import find_spec from importlib.util import find_spec
from typing import cast from typing import cast
...@@ -13,6 +12,7 @@ from vllm.config import ( ...@@ -13,6 +12,7 @@ from vllm.config import (
CUDAGraphMode, CUDAGraphMode,
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
replace,
) )
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -1213,6 +1213,21 @@ class SpecDecodeBaseProposer: ...@@ -1213,6 +1213,21 @@ class SpecDecodeBaseProposer:
model = model.module model = model.module
return model.__class__.__name__ return model.__class__.__name__
def _create_draft_vllm_config(self) -> VllmConfig:
    """Return the VllmConfig for the proposer, with kernel-level overrides.

    When the speculative config sets ``moe_backend``, the result is
    ``self.vllm_config`` with that backend placed on its kernel config;
    otherwise ``self.vllm_config`` is returned as-is. Subclasses may
    override this to apply further config changes.
    """
    draft_backend = self.speculative_config.moe_backend
    if draft_backend is None:
        # No draft-specific backend requested: reuse the existing config.
        return self.vllm_config

    target_config = self.vllm_config
    draft_kernel_config = replace(
        target_config.kernel_config,
        moe_backend=draft_backend,
    )
    return replace(target_config, kernel_config=draft_kernel_config)
def _get_model(self) -> nn.Module: def _get_model(self) -> nn.Module:
""" """
Default method to call get_model(). Can be overridden by subclasses which Default method to call get_model(). Can be overridden by subclasses which
...@@ -1220,9 +1235,10 @@ class SpecDecodeBaseProposer: ...@@ -1220,9 +1235,10 @@ class SpecDecodeBaseProposer:
""" """
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("eagle_head"): with set_model_tag("eagle_head"):
model = get_model( model = get_model(
vllm_config=self.vllm_config, vllm_config=draft_vllm_config,
model_config=self.speculative_config.draft_model_config, model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config, load_config=self.speculative_config.draft_load_config,
) )
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
...@@ -258,30 +257,6 @@ def compute_new_slot_mapping( ...@@ -258,30 +257,6 @@ def compute_new_slot_mapping(
return new_slot_mapping return new_slot_mapping
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Derive a VllmConfig suitable for loading the draft model.

    The incoming config describes the target model (its quant_config and
    parallel_config), but the drafter may be quantized differently and may
    use a different tensor_parallel_size. The returned config clears
    ``quant_config`` and takes its parallel and model configs from the
    speculative config, with the parallel config given the target's rank.
    It is intended for loading the draft model with get_model().
    """
    spec_config = target_model_vllm_config.speculative_config
    assert spec_config is not None, "speculative_config is not set"

    draft_parallel_config = replace(
        spec_config.draft_parallel_config,
        rank=target_model_vllm_config.parallel_config.rank,
    )
    return replace(
        target_model_vllm_config,
        quant_config=None,
        parallel_config=draft_parallel_config,
        model_config=spec_config.draft_model_config,
    )
def extend_all_queries_by_N( def extend_all_queries_by_N(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
N: int, N: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment