Unverified Commit af3162d3 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Spec Decode] Unified Parallel Drafting (#32887)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent 5b2a9422
...@@ -75,6 +75,7 @@ def parse_args(): ...@@ -75,6 +75,7 @@ def parse_args():
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None) parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="") parser.add_argument("--allowed-local-media-path", type=str, default="")
return parser.parse_args() return parser.parse_args()
...@@ -121,6 +122,7 @@ def main(args): ...@@ -121,6 +122,7 @@ def main(args):
"model": eagle_dir, "model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens, "num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch, "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
"parallel_drafting": args.parallel_drafting,
} }
elif args.method == "ngram": elif args.method == "ngram":
speculative_config = { speculative_config = {
...@@ -137,6 +139,7 @@ def main(args): ...@@ -137,6 +139,7 @@ def main(args):
"num_speculative_tokens": args.num_spec_tokens, "num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager, "enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len, "max_model_len": args.max_model_len,
"parallel_drafting": args.parallel_drafting,
} }
elif args.method == "mtp": elif args.method == "mtp":
speculative_config = { speculative_config = {
......
...@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams ...@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import ( from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
create_vllm_config_for_draft_model,
merge_toks_kernel,
)
MTP_SIMILARITY_RATE = 0.8 MTP_SIMILARITY_RATE = 0.8
...@@ -625,6 +622,8 @@ class ArgsTest: ...@@ -625,6 +622,8 @@ class ArgsTest:
expected_acceptance_rate: float expected_acceptance_rate: float
expected_acceptance_len: float expected_acceptance_len: float
# Defaults # Defaults
enforce_eager: bool = True
parallel_drafting: bool = False
target_tensor_parallel_size: int = 1 target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1 draft_tensor_parallel_size: int = 1
max_model_len: int = 1024 max_model_len: int = 1024
...@@ -658,7 +657,8 @@ cases = [ ...@@ -658,7 +657,8 @@ cases = [
@pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager) args.enforce_eager = enforce_eager
assert_draft_model_correctness(args)
def test_draft_model_realistic_example(): def test_draft_model_realistic_example():
...@@ -668,11 +668,28 @@ def test_draft_model_realistic_example(): ...@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
dataset="likaixin/InstructCoder", dataset="likaixin/InstructCoder",
num_speculative_tokens=3, num_speculative_tokens=3,
sampling_config=greedy_sampling(), sampling_config=greedy_sampling(),
enforce_eager=False,
# values below are not derived, but just prevent a regression # values below are not derived, but just prevent a regression
expected_acceptance_len=2.8, expected_acceptance_len=2.8,
expected_acceptance_rate=0.55, expected_acceptance_rate=0.55,
) )
assert_draft_model_correctness(args, enforce_eager=False) assert_draft_model_correctness(args)
def test_draft_model_parallel_drafting():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="amd/PARD-Qwen3-0.6B",
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
parallel_drafting=True,
enforce_eager=False,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len=2.375,
expected_acceptance_rate=0.45,
)
assert_draft_model_correctness(args)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): ...@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
target_model=tgt_model, target_model=tgt_model,
draft_model=draft_model, draft_model=draft_model,
**some_high_acceptance_metrics(), **some_high_acceptance_metrics(),
enforce_eager=enforce_eager,
) )
assert_draft_model_correctness(sd_case, enforce_eager) assert_draft_model_correctness(sd_case)
def test_draft_model_tensor_parallelism(): def test_draft_model_tensor_parallelism():
...@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism(): ...@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism():
draft_model="Qwen/Qwen3-0.6B", draft_model="Qwen/Qwen3-0.6B",
draft_tensor_parallel_size=2, draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(), **some_high_acceptance_metrics(),
enforce_eager=False,
) )
assert_draft_model_correctness(sd_case, enforce_eager=False) assert_draft_model_correctness(sd_case)
def test_draft_model_engine_args_tensor_parallelism(): def test_draft_model_engine_args_tensor_parallelism():
...@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname(): ...@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
engine_args.create_engine_config() engine_args.create_engine_config()
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): def assert_draft_model_correctness(args: ArgsTest):
"""Compare the outputs using and not using speculative decoding. """Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY.""" In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts: list[Messages] = get_messages( test_prompts: list[Messages] = get_messages(
...@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): ...@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"method": "draft_model", "method": "draft_model",
"num_speculative_tokens": args.num_speculative_tokens, "num_speculative_tokens": args.num_speculative_tokens,
"max_model_len": args.max_model_len, "max_model_len": args.max_model_len,
"enforce_eager": enforce_eager, "enforce_eager": args.enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size, "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"parallel_drafting": args.parallel_drafting,
}, },
max_num_seqs=100, # limit cudagraph capture runtime max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len, max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization, gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size, tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=args.enforce_eager,
disable_log_stats=False, # enables get_metrics() disable_log_stats=False, # enables get_metrics()
) )
# we don't check the outputs, only check the metrics # we don't check the outputs, only check the metrics
...@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict: ...@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
} }
def test_merge_toks_kernel():
device = "cuda"
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary
is_rejected_tok = torch.full((merged_len,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 3], device=device),
query_end_locs_ptr=torch.tensor([2, 4], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=5,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def test_merge_toks_kernel_with_rejected_tokens():
device = "cuda"
merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2
merged = torch.full((merged_size,), -100, device=device)
is_rejected_tok = torch.full((merged_size,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 6], device=device),
query_end_locs_ptr=torch.tensor([2, 7], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=9,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor(
[False, False, False, False, True, True, True, False, False, False, True],
device=device,
)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def compute_acceptance_rate(metrics: list[Metric]) -> float: def compute_acceptance_rate(metrics: list[Metric]) -> float:
name2metric = {metric.name: metric for metric in metrics} name2metric = {metric.name: metric for metric in metrics}
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
......
...@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig ...@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -34,6 +35,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -34,6 +35,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
def _create_proposer( def _create_proposer(
...@@ -41,11 +43,19 @@ def _create_proposer( ...@@ -41,11 +43,19 @@ def _create_proposer(
num_speculative_tokens: int, num_speculative_tokens: int,
attention_backend: str | None = None, attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False,
) -> EagleProposer: ) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Choose model directory based on method # Method-dependent setup
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir if method == "eagle":
draft_model_dir = eagle_dir
elif method == "eagle3":
draft_model_dir = eagle3_dir
elif method == "draft_model":
draft_model_dir = ar_draft_model_dir
else:
raise ValueError(f"Unknown method: {method}")
spec_token_tree_str = None spec_token_tree_str = None
if speculative_token_tree is not None: if speculative_token_tree is not None:
...@@ -59,13 +69,18 @@ def _create_proposer( ...@@ -59,13 +69,18 @@ def _create_proposer(
method=method, method=method,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
speculative_token_tree=spec_token_tree_str, speculative_token_tree=spec_token_tree_str,
parallel_drafting=parallel_drafting,
) )
if parallel_drafting:
# Overwrite pard_token to avoid crash during init
speculative_config.draft_model_config.hf_config.pard_token = 0
device = current_platform.device_type
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(), cache_config=CacheConfig(),
speculative_config=speculative_config, speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type), device_config=DeviceConfig(device=device),
parallel_config=ParallelConfig(), parallel_config=ParallelConfig(),
load_config=LoadConfig(), load_config=LoadConfig(),
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
...@@ -75,7 +90,10 @@ def _create_proposer( ...@@ -75,7 +90,10 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend), attention_config=AttentionConfig(backend=attention_backend),
) )
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) if "eagle" in method:
return EagleProposer(vllm_config=vllm_config, device=device)
else:
return DraftModelProposer(vllm_config=vllm_config, device=device)
def test_prepare_next_token_ids(): def test_prepare_next_token_ids():
...@@ -321,6 +339,390 @@ def test_prepare_inputs_padded(): ...@@ -321,6 +339,390 @@ def test_prepare_inputs_padded():
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
def test_set_inputs_first_pass_default_eagle():
"""
Test for set_inputs_first_pass without extra input slots (default EAGLE).
This tests the path where needs_extra_input_slots=False, which is the
default EAGLE pathway. In this case:
- Input IDs are rotated (shifted by one)
- The next_token_ids are inserted at the last position of each request
- Positions are copied as-is
- Hidden states are copied as-is
- The CommonAttentionMetadata is returned unchanged
Setup:
- 3 requests with query_lens [3, 2, 4]
- Tokens: [a1, a2, a3, b1, b2, c1, c2, c3, c4]
- After rotation: [a2, a3, -, b2, -, c2, c3, c4, -]
- After inserting next_tokens [100, 200, 300]:
[a2, a3, 100, b2, 200, c2, c3, c4, 300]
"""
device = torch.device(current_platform.device_type)
num_speculative_tokens = 3
proposer = _create_proposer("eagle", num_speculative_tokens)
# Setup batch with 3 requests
batch_spec = BatchSpec(
seq_lens=[10, 8, 12], # Arbitrary context lengths
query_lens=[3, 2, 4],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
# Input tensors
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
# Request 1: tokens [20, 21] at positions [6, 7]
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
target_token_ids = torch.tensor(
[10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device
)
target_positions = torch.tensor(
[7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device
)
target_hidden_states = torch.randn(
9, proposer.hidden_size, dtype=proposer.dtype, device=device
)
next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)
num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=None,
cad=common_attn_metadata,
num_rejected_tokens_gpu=None,
)
assert num_tokens == 9 # Total tokens unchanged
expected_token_indices_to_sample = torch.tensor(
[2, 4, 8], dtype=torch.int32, device=device
)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
assert output_cad is common_attn_metadata
# Verify input_ids are rotated and next_tokens inserted
# Original: [10, 11, 12, 20, 21, 30, 31, 32, 33]
# After shift by 1: [11, 12, 12, 21, 21, 31, 32, 33, 33]
# After inserting at last indices [2, 4, 8]: [11, 12, 100, 21, 200, 31, 32, 33, 300]
expected_input_ids = torch.tensor(
[11, 12, 100, 21, 200, 31, 32, 33, 300], dtype=torch.int32, device=device
)
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)
# Verify positions are copied as-is
assert torch.equal(proposer.positions[:num_tokens], target_positions)
# Verify hidden states are copied as-is
assert torch.equal(proposer.hidden_states[:num_tokens], target_hidden_states)
def test_set_inputs_first_pass_draft_model():
"""
Test for set_inputs_first_pass with a draft model (extra input slots,
no shift).
This tests the path where needs_extra_input_slots=True and
shift_input_ids=False (draft model case). In this case:
- Input IDs are NOT shifted
- Each request gets extra_slots_per_request (1) new slots
- The kernel handles copying tokens and inserting bonus/padding tokens
- A new CommonAttentionMetadata is returned with updated query_start_loc
Setup:
- 2 requests
- Request 0: tokens [10, 11, 12] at positions [0, 1, 2]
- Only tokens [10, 11] are "valid" (query_end_loc=1),
token 12 is a rejected token from previous speculation
- Request 1: tokens [20, 21] at positions [0, 1], both valid.
- Note: this is less than num_speculative_tokens (2) to ensure
we handle variable lengths correctly.
- next_token_ids: [100, 200] (bonus tokens)
With extra_slots_per_request=1 and shift=False:
Expected output layout:
Request 0 (indices 0-3):
- idx 0: token 10, pos 0
- idx 1: token 11, pos 1
- idx 2: token 100, pos 2 (bonus token)
- idx 3: padding_token_id, is_rejected=True
Request 1 (indices 4-6):
- idx 4: token 20, pos 0
- idx 5: token 21, pos 1
- idx 6: token 200, pos 2 (bonus token)
"""
device = torch.device(current_platform.device_type)
num_speculative_tokens = 2
block_size = 16
# Create a proposer configured as a draft model (pass_hidden_states=False)
# We need to mock this since _create_proposer defaults to EAGLE
proposer = _create_proposer("draft_model", num_speculative_tokens)
proposer.parallel_drafting_token_id = 0
proposer.is_rejected_token_mask = torch.zeros(
proposer.max_num_tokens, dtype=torch.bool, device=device
)
proposer.is_masked_token_mask = torch.zeros(
proposer.max_num_tokens, dtype=torch.bool, device=device
)
# Mock the attn_metadata_builder to avoid needing the full model setup
mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder
# Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
batch_spec = BatchSpec(
seq_lens=[3, 2],
query_lens=[3, 2],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=block_size,
device=device,
arange_block_indices=True, # Use predictable block indices
)
# Input tensors
target_token_ids = torch.tensor(
[10, 11, 12, 20, 21], dtype=torch.int32, device=device
)
target_positions = torch.tensor([0, 1, 2, 0, 1], dtype=torch.int64, device=device)
target_hidden_states = torch.randn(
5, proposer.hidden_size, dtype=proposer.dtype, device=device
)
next_token_ids = torch.tensor([100, 200], dtype=torch.int32, device=device)
num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device)
num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=None,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
assert proposer.net_num_new_slots_per_request == 1
assert proposer.needs_extra_input_slots
# total_output_tokens = total_input_tokens + net_num_new_slots * batch_size
assert num_tokens == 7
# Request 0: [10, 11, 100, padding_token (0)]
# Request 1: [20, 21, 200]
# Combined: [10, 11, 100, 0, 20, 21, 200]
expected_input_ids = torch.tensor(
[10, 11, 100, 0, 20, 21, 200], dtype=torch.int32, device=device
)
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)
# Verify positions
# Request 0: [0, 1, 2, 0 (don't care)]
# Request 1: [0, 1, 2]
# Combined: [0, 1, 2, 0, 0, 1, 2]
expected_positions = torch.tensor(
[0, 1, 2, 0, 0, 1, 2], dtype=torch.int64, device=device
)
assert torch.equal(
proposer.positions[:num_tokens],
expected_positions,
)
# Verify rejection mask
expected_is_rejected = torch.zeros(7, dtype=torch.bool, device=device)
expected_is_rejected[3] = True # padding token at index 3
assert torch.equal(
proposer.is_rejected_token_mask[:num_tokens], expected_is_rejected
)
# Verify masked token mask (should all be False for non-parallel drafting)
expected_is_masked = torch.zeros(7, dtype=torch.bool, device=device)
assert torch.equal(proposer.is_masked_token_mask[:num_tokens], expected_is_masked)
# Verify token_indices_to_sample (bonus tokens at indices 2 and 6)
expected_token_indices_to_sample = torch.tensor(
[2, 6], dtype=torch.int32, device=device
)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
# Verify the new CAD has updated query_start_loc
# Original: [0, 3, 5] -> New: [0, 4, 7] (each request gains 1 slot)
expected_query_start_loc = torch.tensor([0, 4, 7], dtype=torch.int32, device=device)
assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)
def test_set_inputs_first_pass_parallel_drafting():
"""
Test for set_inputs_first_pass with parallel drafting (extra input slots,
with shift).
This tests the path where needs_extra_input_slots=True and
shift_input_ids=True (parallel drafting case). In this case:
- Input IDs ARE shifted (like default EAGLE)
- Each request gets extra_slots_per_request (3) new slots
- Parallel drafting tokens are inserted and marked as masked
- Hidden states are mapped correctly
Setup:
- 2 requests with query_lens [4, 4] (1 bonus + 3 spec tokens each)
- Request 0: tokens [10, 11, 12, 13] at positions [5, 6, 7, 8]
- Only tokens [10, 11, 12] are "valid", token 13 is rejected
- Request 1: tokens [20, 21, 22, 23] at positions [10, 11, 12, 13], all valid.
- next_token_ids: [100, 200] (bonus tokens)
With shift_input_ids=True, extra_slots_per_request=3:
Expected output layout:
Request 0 (6 output slots = 4 - 1 + 3):
- idx 0-2: shifted tokens [11, 12, 100]
- idx 3-4: parallel_drafting_tokens, is_masked=True
- idx 5: padding_token, is_rejected=True
Request 1 (6 output slots = 4 - 1 + 3):
- idx 6-8: shifted tokens [21, 22, 23]
- idx 9: bonus token 200
- idx 10-11: parallel_drafting_tokens, is_masked=True
"""
device = torch.device(current_platform.device_type)
num_speculative_tokens = 3
block_size = 16
proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True)
# Override to simulate parallel drafting behavior
proposer.parallel_drafting_token_id = -2
proposer.parallel_drafting_hidden_state_tensor = torch.zeros(
proposer.hidden_size, dtype=proposer.dtype, device=device
)
proposer.is_rejected_token_mask = torch.zeros(
proposer.max_num_tokens, dtype=torch.bool, device=device
)
proposer.is_masked_token_mask = torch.zeros(
proposer.max_num_tokens, dtype=torch.bool, device=device
)
# Mock the attn_metadata_builder
mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder
# Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
batch_spec = BatchSpec(
seq_lens=[9, 14],
query_lens=[4, 4],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=block_size,
device=device,
arange_block_indices=True,
)
# Input tensors
target_token_ids = torch.tensor(
[10, 11, 12, 13, 20, 21, 22, 23], dtype=torch.int32, device=device
)
target_positions = torch.tensor(
[5, 6, 7, 8, 10, 11, 12, 13], dtype=torch.int64, device=device
)
target_hidden_states = torch.arange(
8 * proposer.hidden_size, dtype=proposer.dtype, device=device
).view(8, proposer.hidden_size)
next_token_ids = torch.tensor([100, 200], dtype=torch.int32, device=device)
num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device)
num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=None,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
# total_output_tokens = total_input_tokens + net_num_new_slots * batch_size
# = 8 + 2 * 2 = 12
assert num_tokens == 12
# Request 0: [11, 12, 100, -2, -2, 0(padding)]
# Request 1: [21, 22, 23, 200, -2, -2]
expected_input_ids = torch.tensor(
[11, 12, 100, -2, -2, 0, 21, 22, 23, 200, -2, -2],
dtype=torch.int32,
device=device,
)
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)
# Verify positions
# Request 0: [5, 6, 7, 8, 9, 0 (don't care)]
# Request 1: [10, 11, 12, 13, 14, 15]
expected_positions = torch.tensor(
[5, 6, 7, 8, 9, 0, 10, 11, 12, 13, 14, 15], dtype=torch.int64, device=device
)
assert torch.equal(
proposer.positions[:num_tokens],
expected_positions,
)
# Verify rejection mask
expected_is_rejected = torch.zeros(12, dtype=torch.bool, device=device)
expected_is_rejected[5] = True
assert torch.equal(
proposer.is_rejected_token_mask[:num_tokens], expected_is_rejected
)
# Verify masked token mask (parallel drafting slots should be masked)
expected_is_masked = torch.zeros(12, dtype=torch.bool, device=device)
expected_is_masked[3] = True
expected_is_masked[4] = True
expected_is_masked[10] = True
expected_is_masked[11] = True
assert torch.equal(proposer.is_masked_token_mask[:num_tokens], expected_is_masked)
# Verify token_indices_to_sample (bonus + parallel drafting tokens)
# Request 0: bonus at 2, parallel at 3, 4
# Request 1: bonus at 9, parallel at 10, 11
expected_token_indices_to_sample = torch.tensor(
[2, 3, 4, 9, 10, 11], dtype=torch.int32, device=device
)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
# Verify the new CAD has updated query_start_loc
# Original query_lens: [4, 4] -> Output: [6, 6]
expected_query_start_loc = torch.tensor(
[0, 6, 12], dtype=torch.int32, device=device
)
assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)
# Verify masked positions have the parallel drafting hidden state (zeros)
parallel_drafting_hs = proposer.parallel_drafting_hidden_state_tensor
for i in range(num_tokens):
if expected_is_masked[i]:
assert torch.equal(proposer.hidden_states[i], parallel_drafting_hs), (
f"Masked position {i} should have parallel drafting hidden state"
)
@pytest.mark.parametrize("method", ["eagle", "eagle3"]) @pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("pp_size", [1, 2])
...@@ -579,7 +981,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -579,7 +981,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
last_token_indices=None, token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
...@@ -737,7 +1139,7 @@ def test_propose_tree(spec_token_tree): ...@@ -737,7 +1139,7 @@ def test_propose_tree(spec_token_tree):
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
last_token_indices=None, token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
......
...@@ -204,7 +204,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): ...@@ -204,7 +204,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
last_token_indices=None, token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
......
...@@ -116,9 +116,16 @@ class SpeculativeConfig: ...@@ -116,9 +116,16 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if """Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1.""" provided. Defaults to 1."""
# Alternative drafting strategies
speculative_token_tree: str | None = None speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation. """Specifies the tree structure for speculative token generation.
""" """
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine # required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model.""" """The configuration of the target model."""
......
...@@ -604,10 +604,13 @@ class VllmConfig: ...@@ -604,10 +604,13 @@ class VllmConfig:
# Currently, async scheduling only support eagle speculative # Currently, async scheduling only support eagle speculative
# decoding. # decoding.
if self.speculative_config is not None: if self.speculative_config is not None:
if self.speculative_config.method not in get_args(EagleModelTypes): if (
self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method != "draft_model"
):
raise ValueError( raise ValueError(
"Currently, async scheduling is only supported " "Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding." "with EAGLE/MTP/Draft Model kind of speculative decoding."
) )
if self.speculative_config.disable_padded_drafter_batch: if self.speculative_config.disable_padded_drafter_batch:
raise ValueError( raise ValueError(
...@@ -1298,16 +1301,21 @@ class VllmConfig: ...@@ -1298,16 +1301,21 @@ class VllmConfig:
computed_compile_ranges_split_points = [] computed_compile_ranges_split_points = []
# The upper bound of the compile ranges is the max_num_batched_tokens. # The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended # For speculative decoding, the compile range must be extended
# by 1 for each sequence. # - Sequential: + 1 * max_num_seqs (one draft token per iteration)
# - Parallel draft: + num_speculative_tokens * max_num_seqs
compile_range_end = self.scheduler_config.max_num_batched_tokens compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None: if compile_range_end is not None:
do_extend: bool = ( if self.speculative_config is not None and (
self.speculative_config is not None self.speculative_config.uses_draft_model()
and self.speculative_config.uses_draft_model() or self.speculative_config.use_eagle()
):
multiplier = (
self.speculative_config.num_speculative_tokens
if self.speculative_config.parallel_drafting
else 1
) )
if do_extend: compile_range_end += multiplier * self.scheduler_config.max_num_seqs
compile_range_end += self.scheduler_config.max_num_seqs
computed_compile_ranges_split_points.append(compile_range_end) computed_compile_ranges_split_points.append(compile_range_end)
......
...@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer): ...@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
# Subsequent layers use hidden_size (only hidden_states, no embeds) # Subsequent layers use hidden_size (only hidden_states, no embeds)
qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
# override qkv # Parallel drafting checkpoints may have attention bias enabled
qkv_bias = getattr(config, "attention_bias", False)
# Override qkv_proj with correct input size and bias setting
self.self_attn.qkv_proj = QKVParallelLinear( self.self_attn.qkv_proj = QKVParallelLinear(
qkv_input_size, qkv_input_size,
self.self_attn.head_dim, self.self_attn.head_dim,
self.self_attn.total_num_heads, self.self_attn.total_num_heads,
self.self_attn.total_num_kv_heads, self.self_attn.total_num_kv_heads,
bias=False, bias=qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "qkv_proj"), prefix=maybe_prefix(prefix, "qkv_proj"),
) )
...@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False, requires_grad=False,
) )
self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting
if self.use_parallel_drafting:
self.register_buffer(
"mask_hidden",
torch.zeros(
1,
(3 if self.model.use_aux_hidden_state else 1)
* self.config.hidden_size,
),
persistent=False,
)
def embed_input_ids( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
model_weights = {} model_weights = {}
includes_draft_id_mapping = False includes_draft_id_mapping = False
includes_embed_tokens = False includes_embed_tokens = False
includes_mask_hidden = False
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "t2d" in name: if "t2d" in name:
continue continue
if "d2t" in name: if "d2t" in name:
name = name.replace("d2t", "draft_id_to_target_id") name = name.replace("d2t", "draft_id_to_target_id")
includes_draft_id_mapping = True includes_draft_id_mapping = True
elif "mask_hidden" in name:
# Load mask_hidden directly into buffer
if not self.use_parallel_drafting:
logger.warning(
"mask_hidden found in weights but "
"model is not configured for parallel drafting. "
"Skipping loading mask_hidden."
)
continue
self.mask_hidden.copy_(loaded_weight.view(1, -1))
includes_mask_hidden = True
continue
elif "lm_head" not in name: elif "lm_head" not in name:
name = "model." + name name = "model." + name
if "embed_tokens" in name: if "embed_tokens" in name:
...@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
model_weights[name] = loaded_weight model_weights[name] = loaded_weight
process_eagle_weight(self, name) process_eagle_weight(self, name)
skip_substrs = [] if not includes_mask_hidden and self.use_parallel_drafting:
raise ValueError(
"mask_hidden not found in weights but "
"model is configured for parallel drafting. "
"Please provide mask_hidden in the weights."
)
skip_substrs = ["mask_hidden"]
if not includes_draft_id_mapping: if not includes_draft_id_mapping:
skip_substrs.append("draft_id_to_target_id") skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens: if not includes_embed_tokens:
......
...@@ -480,9 +480,14 @@ class AttentionMetadataBuilder(ABC, Generic[M]): ...@@ -480,9 +480,14 @@ class AttentionMetadataBuilder(ABC, Generic[M]):
speculative_config is not None speculative_config is not None
and speculative_config.num_speculative_tokens is not None and speculative_config.num_speculative_tokens is not None
): ):
max_num_queries_for_spec = (
1
+ (2 if speculative_config.parallel_drafting else 1)
* speculative_config.num_speculative_tokens
)
self.reorder_batch_threshold = max( self.reorder_batch_threshold = max(
self.reorder_batch_threshold, self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens, max_num_queries_for_spec,
) )
if ( if (
......
...@@ -60,7 +60,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -60,7 +60,7 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
...@@ -658,12 +658,36 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -658,12 +658,36 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport: ) -> AttentionCGSupport:
has_trtllm_support = can_use_trtllm_attention( """Get the cudagraph support level for FlashInfer attention.
num_qo_heads=vllm_config.model_config.get_num_attention_heads(
This depends on whether we can use TRTLLM attention for decodes, since we can
only do UNIFORM_SINGLE_TOKEN_DECODE if it is unavailable.
To check this, we must call can_use_trtllm_attention with the number of KV
heads from the kv_cache_spec. We check all available KV cache specs and
only return UNIFORM_BATCH if all of them support TRTLLM attention.
"""
# For UniformTypeKVCacheSpecs, check all contained specs
kv_specs = (
kv_cache_spec.kv_cache_specs.values()
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs)
else [kv_cache_spec]
)
num_qo_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config vllm_config.parallel_config
),
num_kv_heads=kv_cache_spec.num_kv_heads,
) )
has_trtllm_support: bool = len(kv_specs) > 0
for spec in kv_specs:
if not isinstance(spec, AttentionSpec):
# FlashInfer only applies to attention, so we don't consider other types
# of KV spec (e.g. Mamba) here. This is mostly for type checking.
continue
if not can_use_trtllm_attention(
num_qo_heads=num_qo_heads,
num_kv_heads=spec.num_kv_heads,
):
has_trtllm_support = False
break
if has_trtllm_support: if has_trtllm_support:
return AttentionCGSupport.UNIFORM_BATCH return AttentionCGSupport.UNIFORM_BATCH
else: else:
......
...@@ -825,38 +825,6 @@ def get_dcp_local_seq_lens( ...@@ -825,38 +825,6 @@ def get_dcp_local_seq_lens(
return dcp_local_seq_lens.squeeze(1) return dcp_local_seq_lens.squeeze(1)
def extend_all_queries_by_1(
common_attn_metadata: CommonAttentionMetadata,
arange: torch.Tensor,
new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by 1.
Also all seq lens are increased by 1.
This is useful e.g. in speculative decoding with draft models, where we
extend each sequence by 1 token.
The slot mapping is computed externally, as it requires more information.
"""
cad = common_attn_metadata
# query start loc must be increased by [+0, +1, +2, ..., +batch_size]
new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)]
new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
len(cad.query_start_loc_cpu), dtype=torch.int32
)
new_cad = cad.replace(
query_start_loc=new_query_start_loc,
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=cad.seq_lens + 1,
# each request is extended by 1 token -> batch_size tokens are added
num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
# All query lens increase by 1, so max query len increases by 1
max_query_len=cad.max_query_len + 1,
max_seq_len=cad.max_seq_len + 1,
slot_mapping=new_slot_mapping,
)
return new_cad
def mamba_get_block_table_tensor( def mamba_get_block_table_tensor(
block_table: torch.Tensor, block_table: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch import torch
import torch.nn as nn
from typing_extensions import override
from vllm.config import VllmConfig, get_layers_from_vllm_config, replace from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.attention.backends.utils import ( from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
CommonAttentionMetadata,
extend_all_queries_by_1,
)
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,37 +27,9 @@ class DraftModelProposer(SpecDecodeBaseProposer): ...@@ -31,37 +27,9 @@ class DraftModelProposer(SpecDecodeBaseProposer):
pass_hidden_states_to_model=False, pass_hidden_states_to_model=False,
runner=runner, runner=runner,
) )
self._raise_if_multimodal()
self._raise_if_mrope()
self._raise_if_padded_drafter_batch_disabled()
self._raise_if_vocab_size_mismatch() self._raise_if_vocab_size_mismatch()
self._raise_if_draft_tp_mismatch() self._raise_if_draft_tp_mismatch()
def _block_size(self) -> int:
builder = self._get_attention_metadata_builder()
return builder.kv_cache_spec.block_size
def _raise_if_multimodal(self):
if self.supports_mm_inputs:
raise NotImplementedError(
"Speculative Decoding with draft models "
"does not support multimodal models yet"
)
def _raise_if_mrope(self):
if self.draft_model_config.uses_mrope:
raise NotImplementedError(
"Speculative Decoding with draft models does not support M-RoPE yet"
)
def _raise_if_padded_drafter_batch_disabled(self):
if self.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
" in the speculative_config."
)
def _raise_if_vocab_size_mismatch(self): def _raise_if_vocab_size_mismatch(self):
self.speculative_config.verify_equal_vocab_size_if_draft_model() self.speculative_config.verify_equal_vocab_size_if_draft_model()
...@@ -82,193 +50,26 @@ class DraftModelProposer(SpecDecodeBaseProposer): ...@@ -82,193 +50,26 @@ class DraftModelProposer(SpecDecodeBaseProposer):
"Please pass 'draft_tensor_parallel_size' in the speculative_config." "Please pass 'draft_tensor_parallel_size' in the speculative_config."
) )
def set_inputs_first_pass( @override
self, def _get_model(self) -> nn.Module:
target_token_ids: torch.Tensor, # Draft models may be quantized or on different parallelism,
next_token_ids: torch.Tensor, # so we load them with a modified vllm config
target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
batch_size = cad.batch_size()
grid = (batch_size,)
start_locs = cad.query_start_loc[:-1]
end_locs = cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
end_locs -= num_rejected_tokens_gpu
num_tokens = target_token_ids.shape[0] + batch_size
is_rejected_tok = torch.empty(
(num_tokens,), device=self.input_ids.device, dtype=torch.bool
)
merge_toks_kernel[grid](
target_toks_ptr=target_token_ids,
next_toks_ptr=next_token_ids,
query_start_locs_ptr=start_locs,
query_end_locs_ptr=end_locs,
out_ptr_merged_toks=self.input_ids,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=target_token_ids.shape[0],
# passing a negative rejected_tok_fill value will raise an error
# when the value is used to index into embeddings.
# Therefore, we pass a valid integer, e.g. 0.
rejected_tok_fill=0,
)
merge_toks_kernel[grid](
target_toks_ptr=target_positions,
next_toks_ptr=target_positions[end_locs] + 1,
query_start_locs_ptr=start_locs,
query_end_locs_ptr=end_locs,
out_ptr_merged_toks=self.positions,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=target_positions.shape[0],
rejected_tok_fill=0,
)
# recompute slot mapping
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:num_tokens],
is_rejected_token_mask=is_rejected_tok,
block_size=self._block_size(),
max_model_len=self.max_model_len,
)
# update common_attn_metadata
new_cad: CommonAttentionMetadata = extend_all_queries_by_1(
cad,
arange=self.arange,
new_slot_mapping=new_slot_mapping,
)
new_last_token_indices = new_cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
new_last_token_indices -= num_rejected_tokens_gpu
return num_tokens, new_last_token_indices, new_cad
def load_model(self, target_model: Any) -> None:
"""Takes target_model to satisfy the type checker."""
# This must be computed before loading the draft model
# because that mutates the forward_context of the vllm_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
)
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model( temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
target_model_vllm_config=self.vllm_config
)
logger.info(
"Starting to load draft model %s. TP=%d, rank=%d",
draft_vllm_config.model_config.model,
draft_vllm_config.parallel_config.tensor_parallel_size,
draft_vllm_config.parallel_config.rank,
)
with set_model_tag("draft_model"): with set_model_tag("draft_model"):
self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") model = get_model(
vllm_config=temp_vllm_config,
# This must be computed after loading the draft model prefix="draft_model",
# because that mutates the forward_context of the vllm_config
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
- target_attn_layer_names
)
self.attn_layer_names = list(draft_attn_layer_names)
def create_vllm_config_for_draft_model(
target_model_vllm_config: VllmConfig,
) -> VllmConfig:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the draft model.
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
new_parallel_config = replace(
old_spec_config.draft_parallel_config,
rank=old.parallel_config.rank,
)
new: VllmConfig = replace(
old,
quant_config=None, # quant_config is recomputed in __init__()
model_config=old_spec_config.draft_model_config,
parallel_config=new_parallel_config,
)
return new
def compute_new_slot_mapping(
cad: CommonAttentionMetadata,
new_positions: torch.Tensor,
is_rejected_token_mask: torch.Tensor,
block_size: int,
max_model_len: int,
):
batch_size, n_blocks_per_req = cad.block_table_tensor.shape
req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
req_indices = torch.repeat_interleave(
req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions)
) )
# Clamp the positions to prevent an out-of-bounds error when indexing return model
# into block_table_tensor.
clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)
block_table_indices = (
req_indices * n_blocks_per_req + clamped_positions // block_size
)
block_nums = cad.block_table_tensor.view(-1)[block_table_indices]
block_offsets = clamped_positions % block_size
new_slot_mapping = block_nums * block_size + block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len = new_positions >= max_model_len
new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
return new_slot_mapping
@triton.jit @override
def merge_toks_kernel( def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
target_toks_ptr, # Draft models don't share embeddings with the target model
next_toks_ptr, pass
query_start_locs_ptr,
query_end_locs_ptr,
out_ptr_merged_toks,
out_ptr_is_rejected_tok,
target_toks_size,
rejected_tok_fill,
):
"""
Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor
called `out_ptr_merged_toks`. Rejected tokens are those after the
`query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the
rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask
of the rejected tokens in `out_ptr_is_rejected_tok`.
"""
pid = tl.program_id(0)
start_loc = tl.load(query_start_locs_ptr + pid)
is_last_program = pid == tl.num_programs(0) - 1
if is_last_program:
next_start_loc = target_toks_size.to(tl.int32)
else:
next_start_loc = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32)
end_loc = tl.load(query_end_locs_ptr + pid) @override
new_val = tl.load(next_toks_ptr + pid) def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
for i in range(start_loc, next_start_loc + 1): # Draft models don't share lm_head with the target model
if i <= end_loc: # copy existing tokens pass
old_val = tl.load(target_toks_ptr + i)
tl.store(out_ptr_merged_toks + pid + i, old_val)
tl.store(out_ptr_is_rejected_tok + pid + i, False)
elif i == end_loc + 1: # copy bonus token
tl.store(out_ptr_merged_toks + pid + i, new_val)
tl.store(out_ptr_is_rejected_tok + pid + i, False)
else: # fill rejected tokens
tl.store(out_ptr_merged_toks + pid + i, rejected_tok_fill)
tl.store(out_ptr_is_rejected_tok + pid + i, True)
...@@ -43,8 +43,12 @@ from vllm.v1.sample.metadata import SamplingMetadata ...@@ -43,8 +43,12 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import ( from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
eagle_prepare_inputs_padded_kernel, eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel, eagle_prepare_next_token_padded_kernel,
extend_all_queries_by_N,
) )
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
...@@ -52,8 +56,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -52,8 +56,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger = init_logger(__name__) logger = init_logger(__name__)
PADDING_SLOT_ID = -1
class SpecDecodeBaseProposer: class SpecDecodeBaseProposer:
def __init__( def __init__(
...@@ -76,18 +78,35 @@ class SpecDecodeBaseProposer: ...@@ -76,18 +78,35 @@ class SpecDecodeBaseProposer:
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because # We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's # the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B). # hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# Unifying eagle, draft model, and parallel drafting support
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens
)
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0
)
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
self.parallel_drafting_token_id: int = 0
self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
if self.parallel_drafting:
self._init_parallel_drafting_params()
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + (
self.net_num_new_slots_per_request * max_batch_size
)
self.token_arange_np = np.arange(self.max_num_tokens)
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
...@@ -155,6 +174,26 @@ class SpecDecodeBaseProposer: ...@@ -155,6 +174,26 @@ class SpecDecodeBaseProposer:
max_num_slots_for_arange, device=device, dtype=torch.int32 max_num_slots_for_arange, device=device, dtype=torch.int32
) )
if self.needs_extra_input_slots:
self._raise_if_padded_drafter_batch_disabled()
self._raise_if_multimodal()
self._raise_if_mrope()
self.is_rejected_token_mask: torch.Tensor | None = None
self.is_masked_token_mask: torch.Tensor | None = None
if self.needs_extra_input_slots:
# For draft models and parallel drafting, we need to keep track of
# which tokens are rejected to update the slot mapping with padding slots.
self.is_rejected_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
# For parallel drafting, we also need to keep track of which tokens
# are parallel-padding tokens used to sample at later positions.
# We populate this tensor even when using draft models for simplicity.
self.is_masked_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
self.inputs_embeds = torch.zeros( self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.inputs_embeds_size), (self.max_num_tokens, self.inputs_embeds_size),
dtype=self.dtype, dtype=self.dtype,
...@@ -231,6 +270,49 @@ class SpecDecodeBaseProposer: ...@@ -231,6 +270,49 @@ class SpecDecodeBaseProposer:
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1) ).repeat(max_batch_size, 1)
def _raise_if_padded_drafter_batch_disabled(self):
if self.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models or parallel drafting only "
"supports padded drafter batch. Please unset "
"disable_padded_drafter_batch in the speculative_config."
)
def _raise_if_multimodal(self):
if self.supports_mm_inputs:
raise NotImplementedError(
"Speculative Decoding with draft models or parallel drafting "
"does not support multimodal models yet"
)
def _raise_if_mrope(self):
if self.draft_model_config.uses_mrope:
raise NotImplementedError(
"Speculative Decoding with draft models or parallel drafting "
"does not support M-RoPE yet"
)
def _init_parallel_drafting_params(self):
# For parallel drafting, we need the token ID to use for masked slots
# And for EAGLE + parallel drafting, we need the hidden state tensor to use
# for those masked slots.
model_hf_config = self.draft_model_config.hf_config
if hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
else:
raise ValueError(
"For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json."
)
if self.pass_hidden_states_to_model:
self.parallel_drafting_hidden_state_tensor = torch.empty(
self.hidden_size, dtype=self.dtype, device=self.device
)
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions[:, :num_tokens] return self.mrope_positions[:, :num_tokens]
...@@ -296,7 +378,7 @@ class SpecDecodeBaseProposer: ...@@ -296,7 +378,7 @@ class SpecDecodeBaseProposer:
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
# [batch_size] # [batch_size]
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
last_token_indices: torch.Tensor | None, token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
...@@ -314,12 +396,13 @@ class SpecDecodeBaseProposer: ...@@ -314,12 +396,13 @@ class SpecDecodeBaseProposer:
) )
assert target_hidden_states.shape[-1] == self.hidden_size assert target_hidden_states.shape[-1] == self.hidden_size
num_tokens, last_token_indices, common_attn_metadata = ( num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass( self.set_inputs_first_pass(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
target_positions=target_positions, target_positions=target_positions,
last_token_indices=last_token_indices, target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata, cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu, num_rejected_tokens_gpu=num_rejected_tokens_gpu,
) )
...@@ -366,11 +449,6 @@ class SpecDecodeBaseProposer: ...@@ -366,11 +449,6 @@ class SpecDecodeBaseProposer:
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
if self.pass_hidden_states_to_model:
# target_hidden_states and self.hidden_states can have different
# hidden dims. E.g. large target model and small draft model.
self.hidden_states[:num_tokens] = target_hidden_states
if self.supports_mm_inputs: if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
...@@ -411,27 +489,27 @@ class SpecDecodeBaseProposer: ...@@ -411,27 +489,27 @@ class SpecDecodeBaseProposer:
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[token_indices_to_sample]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, self.num_speculative_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, last_token_indices] positions = self.mrope_positions[:, token_indices_to_sample]
else: else:
positions = self.positions[last_token_indices] positions = self.positions[token_indices_to_sample]
if self.method in ( if self.method in (
"deepseek_mtp", "deepseek_mtp",
"ernie_mtp", "ernie_mtp",
"longcat_flash_mtp", "longcat_flash_mtp",
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
): ):
hidden_states = self.hidden_states[last_token_indices] hidden_states = self.hidden_states[token_indices_to_sample]
else: else:
hidden_states = hidden_states[last_token_indices] hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata): if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention. # Draft using tree attention.
...@@ -624,12 +702,17 @@ class SpecDecodeBaseProposer: ...@@ -624,12 +702,17 @@ class SpecDecodeBaseProposer:
target_token_ids: torch.Tensor, target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
target_positions: torch.Tensor, target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None, target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
cad: CommonAttentionMetadata, cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None, num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
if last_token_indices is None: if not self.needs_extra_input_slots:
last_token_indices = cad.query_start_loc[1:] - 1 # Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
# Inserting the next token ids at the last slot in each request.
if token_indices_to_sample is None:
token_indices_to_sample = cad.query_start_loc[1:] - 1
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
# Shift the input ids by one token. # Shift the input ids by one token.
...@@ -637,14 +720,121 @@ class SpecDecodeBaseProposer: ...@@ -637,14 +720,121 @@ class SpecDecodeBaseProposer:
self.input_ids[: num_tokens - 1] = target_token_ids[1:] self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids self.input_ids[token_indices_to_sample] = next_token_ids
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0: if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0] target_positions = target_positions[0]
self._set_positions(num_tokens, target_positions) self._set_positions(num_tokens, target_positions)
return num_tokens, last_token_indices, cad self.hidden_states[:num_tokens] = target_hidden_states
return num_tokens, token_indices_to_sample, cad
else:
assert self.is_rejected_token_mask is not None
assert self.is_masked_token_mask is not None
# 1.
# Call a custom triton kernel to copy input_ids and positions
# into the correct slots in the preallocated buffers self.input_ids,
# self.positions.
batch_size = cad.batch_size()
# Since we might have to copy a lot of data for prefills, we select the
# block size based on the max query length and limit to max 256 slots/block.
max_num_tokens_per_request = (
cad.max_query_len + self.net_num_new_slots_per_request
)
BLOCK_SIZE_TOKENS = min(
256, triton.next_power_of_2(max_num_tokens_per_request)
)
num_blocks = (
max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
) // BLOCK_SIZE_TOKENS
total_num_input_tokens = target_token_ids.shape[0]
total_num_output_tokens = total_num_input_tokens + (
self.net_num_new_slots_per_request * batch_size
)
token_indices_to_sample = torch.empty(
batch_size * self.extra_slots_per_request,
dtype=torch.int32,
device=self.device,
)
# Destination indices to write target_hidden_states into drafting buffer.
out_hidden_state_mapping = torch.empty(
total_num_input_tokens, dtype=torch.int32, device=self.device
)
# Kernel grid: one program per request (row)
grid = (batch_size, num_blocks)
query_start_loc = cad.query_start_loc
query_end_loc = cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
query_end_loc = query_end_loc - num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel[grid](
# (Padded) Inputs from the target model
target_token_ids_ptr=target_token_ids,
target_positions_ptr=target_positions,
next_token_ids_ptr=next_token_ids, # sampled tokens, one per request
# Outputs to the drafting buffers
out_input_ids_ptr=self.input_ids,
out_positions_ptr=self.positions, # Doesn't support mrope for now
out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
out_is_masked_token_mask_ptr=self.is_masked_token_mask,
out_new_token_indices_ptr=token_indices_to_sample,
out_hidden_state_mapping_ptr=out_hidden_state_mapping,
# Input metadata
query_start_loc_ptr=query_start_loc,
query_end_loc_ptr=query_end_loc,
padding_token_id=0,
parallel_drafting_token_id=self.parallel_drafting_token_id,
# Sizing info
# Note that we can deduce batch_size for free from the grid size
total_input_tokens=total_num_input_tokens,
num_padding_slots_per_request=self.extra_slots_per_request,
shift_input_ids=self.pass_hidden_states_to_model,
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
if self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.hidden_states[out_hidden_state_mapping] = target_hidden_states
# Use torch.where to avoid DtoH sync from boolean indexing
mask = self.is_masked_token_mask[:total_num_output_tokens]
torch.where(
mask.unsqueeze(1),
self.parallel_drafting_hidden_state_tensor,
self.hidden_states[:total_num_output_tokens],
out=self.hidden_states[:total_num_output_tokens],
)
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
builder = (
self._get_attention_metadata_builder()
if self.attn_metadata_builder is None
else self.attn_metadata_builder
)
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=builder.kv_cache_spec.block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
# 3. Update the common attention metadata with the new (meta)data
new_cad = extend_all_queries_by_N(
cad,
N=self.net_num_new_slots_per_request,
arange=self.arange,
new_slot_mapping=new_slot_mapping,
)
return total_num_output_tokens, token_indices_to_sample, new_cad
def model_returns_tuple(self) -> bool: def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model") return self.method not in ("mtp", "draft_model")
...@@ -1081,8 +1271,21 @@ class SpecDecodeBaseProposer: ...@@ -1081,8 +1271,21 @@ class SpecDecodeBaseProposer:
model = model.module model = model.module
return model.__class__.__name__ return model.__class__.__name__
def _get_model(self) -> nn.Module:
"""
Default method to call get_model(). Can be overridden by subclasses which
need to customize model loading.
"""
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
model_config=self.speculative_config.draft_model_config,
)
return model
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.speculative_config.draft_model_config
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config( get_layers_from_vllm_config(
self.vllm_config, self.vllm_config,
...@@ -1096,12 +1299,7 @@ class SpecDecodeBaseProposer: ...@@ -1096,12 +1299,7 @@ class SpecDecodeBaseProposer:
).keys() ).keys()
) )
from vllm.compilation.backends import set_model_tag self.model = self._get_model()
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=draft_model_config
)
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config( get_layers_from_vllm_config(
...@@ -1170,7 +1368,26 @@ class SpecDecodeBaseProposer: ...@@ -1170,7 +1368,26 @@ class SpecDecodeBaseProposer:
else: else:
target_language_model = target_model target_language_model = target_model
# share embed_tokens with the target model if needed self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(
self.model.mask_hidden.view(3 * self.hidden_size)
)
if self.eagle3_use_aux_hidden_state
else self.model.mask_hidden.view(self.hidden_size)
)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own embedding layers, and some may
have a duplicate copy of the target model's embedding layers. In these cases,
we share the target model's embedding layers with the draft model to save
memory.
"""
if get_pp_group().world_size == 1: if get_pp_group().world_size == 1:
inner_model = getattr(target_language_model, "model", None) inner_model = getattr(target_language_model, "model", None)
if inner_model is None: if inner_model is None:
...@@ -1233,7 +1450,12 @@ class SpecDecodeBaseProposer: ...@@ -1233,7 +1450,12 @@ class SpecDecodeBaseProposer:
" from the target model." " from the target model."
) )
# share lm_head with the target model if needed def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own LM head, and some may have a
duplicate copy of the target model's LM head. In these cases, we share
the target model's LM head with the draft model to save memory.
"""
share_lm_head = False share_lm_head = False
if hasattr(self.model, "has_own_lm_head"): if hasattr(self.model, "has_own_lm_head"):
# EAGLE model # EAGLE model
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig, replace
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
PADDING_SLOT_ID = -1
@triton.jit @triton.jit
...@@ -107,3 +115,243 @@ def eagle_prepare_next_token_padded_kernel( ...@@ -107,3 +115,243 @@ def eagle_prepare_next_token_padded_kernel(
tl.store(next_token_ids_ptr + req_idx, backup_token) tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count) tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
def compute_new_slot_mapping(
cad: CommonAttentionMetadata,
new_positions: torch.Tensor,
is_rejected_token_mask: torch.Tensor,
block_size: int,
num_new_tokens: int,
max_model_len: int,
):
batch_size, n_blocks_per_req = cad.block_table_tensor.shape
req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
req_indices = torch.repeat_interleave(
req_indices,
cad.naive_query_lens() + num_new_tokens,
output_size=len(new_positions),
)
# Clamp the positions to prevent an out-of-bounds error when indexing
# into block_table_tensor.
clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)
block_table_indices = (
req_indices * n_blocks_per_req + clamped_positions // block_size
)
block_nums = cad.block_table_tensor.view(-1)[block_table_indices]
block_offsets = clamped_positions % block_size
new_slot_mapping = block_nums * block_size + block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len = new_positions >= max_model_len
new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
return new_slot_mapping
def create_vllm_config_for_draft_model(
target_model_vllm_config: VllmConfig,
) -> VllmConfig:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the drafter.
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
new_parallel_config = replace(
old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
)
new: VllmConfig = replace(
old,
quant_config=None,
parallel_config=new_parallel_config,
model_config=old_spec_config.draft_model_config,
)
return new
def extend_all_queries_by_N(
common_attn_metadata: CommonAttentionMetadata,
N: int,
arange: torch.Tensor,
new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by N.
Also all seq lens are increased by N.
This is useful e.g. in speculative decoding with parallel drafting, where we
extend each sequence by N tokens and predict all tokens in one pass.
The slot mapping is computed externally, as it requires more information.
"""
cad = common_attn_metadata
# query start loc must be increased by [+0, +N, +2N, ..., +batch_size * N]
new_query_start_loc = cad.query_start_loc + N * arange[: len(cad.query_start_loc)]
new_query_start_loc_cpu = cad.query_start_loc_cpu + N * torch.arange(
len(cad.query_start_loc_cpu), dtype=torch.int32
)
new_cad = cad.replace(
query_start_loc=new_query_start_loc,
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=cad.seq_lens + N,
# each request is extended by N tokens -> batch_size * N tokens are added
num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N,
# All query lens increase by N, so max query len increases by N
max_query_len=cad.max_query_len + N,
max_seq_len=cad.max_seq_len + N,
slot_mapping=new_slot_mapping,
)
return new_cad
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
# (Padded) Inputs from the target model
target_token_ids_ptr, # [total_tokens_in_batch]
target_positions_ptr, # [total_tokens_in_batch]
next_token_ids_ptr, # [num_reqs]
# Outputs to the drafting buffers
out_input_ids_ptr, # [total_draft_tokens_in_batch] (output)
out_positions_ptr, # [total_draft_tokens_in_batch] (output)
out_is_rejected_token_mask_ptr, # [total_draft_tokens_in_batch] (output)
out_is_masked_token_mask_ptr, # [total_draft_tokens_in_batch] (output)
out_new_token_indices_ptr, # [num_padding_slots_per_request * num_reqs] (output)
out_hidden_state_mapping_ptr, # [total_tokens_in_batch]
# Input metadata
query_start_loc_ptr, # [num_reqs + 1], last value is the total num input tokens
query_end_loc_ptr, # [num_reqs]
padding_token_id, # tl.int32
parallel_drafting_token_id, # tl.int32
# Sizing info
total_input_tokens, # tl.int32
num_padding_slots_per_request, # tl.int32
shift_input_ids, # tl.bool
BLOCK_SIZE_TOKENS: tl.constexpr, # Blocks along token dim to handle prefills
):
"""
Copy and expand inputs from the target model to the drafting buffers for Eagle
speculative decoding. This kernel handles padding slots and parallel drafting
tokens, if enabled.
"""
request_idx = tl.program_id(axis=0)
token_batch_idx = tl.program_id(axis=1)
# Load query locations
query_start_loc = tl.load(query_start_loc_ptr + request_idx)
next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
query_end_loc = tl.load(query_end_loc_ptr + request_idx)
# Calculate number of valid tokens to copy and input offset
# With shift_input_ids=True, we skip the first token
# Output layout: each request gets (input_len + num_padding_slots_per_request) slots
# But with shift, we lose one token per request
if shift_input_ids:
num_valid_tokens = query_end_loc - query_start_loc
input_offset = 1
output_start = query_start_loc + request_idx * (
num_padding_slots_per_request - 1
)
else:
num_valid_tokens = query_end_loc - query_start_loc + 1
input_offset = 0
output_start = query_start_loc + request_idx * num_padding_slots_per_request
# Number of rejected tokens from previous speculation
num_rejected = next_query_start_loc - query_end_loc - 1
# Total output tokens for this request
total_output_tokens = (
num_valid_tokens + num_padding_slots_per_request + num_rejected
)
# Process tokens in this block
j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)
# Compute masks for different output regions:
# [0, num_valid_tokens): valid tokens copied from input
# [num_valid_tokens]: bonus token from next_token_ids
# (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
# parallel drafting slots
# [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
# rejected slots
in_bounds = j < total_output_tokens
is_valid_region = j < num_valid_tokens
is_bonus_region = j == num_valid_tokens
is_parallel_draft_region = (j > num_valid_tokens) & (
j < num_valid_tokens + num_padding_slots_per_request
)
is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request
# Compute output indices
out_idx = output_start + j
# For valid tokens, compute input index
in_idx = query_start_loc + input_offset + j
# Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)
# Load input tokens (masked to valid region)
token_ids = tl.load(
target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
)
# Load the starting position for this request (first position in the sequence)
start_pos = tl.load(target_positions_ptr + query_start_loc)
# Load bonus token for this request
bonus_token = tl.load(next_token_ids_ptr + request_idx)
# Build final token_ids based on region
token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
token_ids = tl.where(
is_parallel_draft_region, parallel_drafting_token_id, token_ids
)
token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)
# Build final positions:
# Positions are NOT shifted - they start from the first input position and increment
# Output position j gets start_pos + j
# (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
positions = start_pos + j
# Rejected positions are don't-care, set to 0
positions = tl.where(is_rejected_region, 0, positions)
# Compute output masks
is_rejected_out = is_rejected_region & in_bounds
is_masked_out = is_parallel_draft_region & in_bounds
# Compute indices of new tokens (bonus + parallel drafting) for sampling
# New tokens are at positions
# [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
is_new_token_region = (j >= num_valid_tokens) & (
j < num_valid_tokens + num_padding_slots_per_request
)
new_token_local_idx = (
j - num_valid_tokens
) # 0 for bonus, 1, 2, ... for parallel drafting
new_token_out_idx = (
request_idx * num_padding_slots_per_request + new_token_local_idx
)
# Compute hidden state mapping (source index -> destination index)
# This maps each input position to its corresponding output position
# Hidden states don't get shifted, so we map all input tokens (including rejected)
if shift_input_ids:
num_input_tokens_this_request = next_query_start_loc - query_start_loc
is_input_region = j < num_input_tokens_this_request
src_idx = query_start_loc + j
tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)
# Store outputs
tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
tl.store(
out_new_token_indices_ptr + new_token_out_idx,
out_idx,
mask=is_new_token_region & in_bounds,
)
...@@ -4090,7 +4090,7 @@ class GPUModelRunner( ...@@ -4090,7 +4090,7 @@ class GPUModelRunner(
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample, token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment