Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4aaaf8c8
Unverified
Commit
4aaaf8c8
authored
Mar 10, 2026
by
Sladyn
Committed by
GitHub
Mar 11, 2026
Browse files
feat(spec_decode): fuse EAGLE step slot mapping and metadata updates (#33503)
Signed-off-by:
sladynnunes
<
snunes@usc.edu
>
parent
4bf53362
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
318 additions
and
59 deletions
+318
-59
tests/v1/spec_decode/test_eagle_step_kernel.py
tests/v1/spec_decode/test_eagle_step_kernel.py
+175
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+35
-59
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+108
-0
No files found.
tests/v1/spec_decode/test_eagle_step_kernel.py
0 → 100644
View file @
4aaaf8c8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the fused EAGLE slot mapping kernel."""
import
pytest
import
torch
from
vllm.v1.spec_decode.utils
import
(
PADDING_SLOT_ID
,
eagle_step_update_slot_mapping_and_metadata
,
)
# Module-level skip guards: the tests below launch a Triton kernel on a CUDA
# device, so the whole module is skipped when Triton cannot be imported or
# when no CUDA device is available.
pytest.importorskip("triton")
if not torch.cuda.is_available():
    # allow_module_level=True is required to call pytest.skip outside a test.
    pytest.skip("CUDA required for EAGLE kernel tests", allow_module_level=True)
def _reference_eagle_step_slot_mapping(
    positions_1d: torch.Tensor,
    block_table_tensor: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_model_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Eager PyTorch oracle for eagle_step_update_slot_mapping_and_metadata.

    Advances every position by one. Requests whose advanced position is
    >= max_model_len get position 0, slot PADDING_SLOT_ID and sequence
    length 1; all others get their slot from the block table and their
    sequence length incremented (capped at max_model_len).

    Inputs are not modified.

    Returns:
        (clamped_positions, slot_mapping, new_seq_lens)
    """
    advanced = positions_1d + 1
    overflow = advanced >= max_model_len

    # Overflowing requests are reset to position 0.
    clamped_positions = advanced.masked_fill(overflow, 0)

    # Column of the block table that holds each position, limited to the
    # last available column.
    last_column = block_table_tensor.shape[1] - 1
    block_numbers = torch.clamp(clamped_positions // block_size, max=last_column)

    row_index = torch.arange(positions_1d.shape[0], device=positions_1d.device)
    block_ids = block_table_tensor[row_index, block_numbers.long()].long()

    offset_in_block = clamped_positions % block_size
    slot_mapping = (block_ids * block_size + offset_in_block).masked_fill(
        overflow, PADDING_SLOT_ID
    )

    incremented = seq_lens + 1
    new_seq_lens = torch.where(
        overflow, torch.ones_like(seq_lens), incremented
    ).clamp(max=max_model_len)

    return clamped_positions, slot_mapping, new_seq_lens
def test_eagle_step_slot_mapping_kernel():
    """Test fused kernel matches Python reference for slot mapping and metadata.

    Random inputs are drawn from a locally seeded generator so that a
    failure can be reproduced exactly and the global RNG state used by
    other tests is left untouched.
    """
    device = torch.device("cuda")
    batch_size = 32
    block_size = 16
    max_model_len = 4096
    n_blocks_per_req = (max_model_len + block_size - 1) // block_size

    rng = torch.Generator(device=device)
    rng.manual_seed(0)

    # Positions stay well below max_model_len, so this test exercises the
    # regular (non-overflow) path only.
    positions_1d = torch.randint(
        0,
        max_model_len - 10,
        (batch_size,),
        generator=rng,
        dtype=torch.int64,
        device=device,
    )
    block_table_tensor = torch.randint(
        0,
        1000,
        (batch_size, n_blocks_per_req),
        generator=rng,
        dtype=torch.int32,
        device=device,
    )
    seq_lens = torch.randint(
        1,
        100,
        (batch_size,),
        generator=rng,
        dtype=torch.int32,
        device=device,
    )

    ref_clamped, ref_slot, ref_seq_lens = _reference_eagle_step_slot_mapping(
        positions_1d.clone(),
        block_table_tensor,
        seq_lens.clone(),
        block_size,
        max_model_len,
    )

    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
    out_slot = torch.zeros(batch_size, dtype=torch.int64, device=device)
    # The kernel updates seq_lens in place; run it on a copy so the original
    # stays available for debugging.
    seq_lens_copy = seq_lens.clone()
    eagle_step_update_slot_mapping_and_metadata(
        positions_1d=positions_1d,
        block_table_tensor=block_table_tensor,
        seq_lens=seq_lens_copy,
        block_size=block_size,
        max_model_len=max_model_len,
        out_clamped_positions=out_clamped,
        out_slot_mapping=out_slot,
    )

    assert torch.equal(out_clamped, ref_clamped), (
        f"clamped: {out_clamped} vs {ref_clamped}"
    )
    assert torch.equal(out_slot, ref_slot), f"slot: {out_slot} vs {ref_slot}"
    assert torch.equal(seq_lens_copy, ref_seq_lens), (
        f"seq_lens: {seq_lens_copy} vs {ref_seq_lens}"
    )
def test_eagle_step_slot_mapping_kernel_exceeds_max():
    """Test fused kernel when position exceeds max_model_len.

    Rows 0 and 1 stay in range after the +1 step (51 and 99); rows 2 and 3
    reach or pass max_model_len (100 and 101) and must be reset.
    """
    device = torch.device("cuda")
    batch_size = 4
    block_size = 16
    max_model_len = 100
    n_blocks_per_req = (max_model_len + block_size - 1) // block_size

    positions_1d = torch.tensor([50, 98, 99, 100], dtype=torch.int64, device=device)
    block_table_tensor = torch.randint(
        0, 100, (batch_size, n_blocks_per_req), dtype=torch.int32, device=device
    )
    seq_lens = torch.tensor([51, 99, 100, 101], dtype=torch.int32, device=device)

    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
    out_slot = torch.zeros(batch_size, dtype=torch.int64, device=device)
    eagle_step_update_slot_mapping_and_metadata(
        positions_1d=positions_1d,
        block_table_tensor=block_table_tensor,
        seq_lens=seq_lens,
        block_size=block_size,
        max_model_len=max_model_len,
        out_clamped_positions=out_clamped,
        out_slot_mapping=out_slot,
    )

    # In-range rows: position advances by one.
    assert out_clamped[0].item() == 51
    assert out_clamped[1].item() == 99
    # Overflowing rows: position is reset to 0.
    assert out_clamped[2].item() == 0
    assert out_clamped[3].item() == 0

    # In-range rows: slot comes from the block table entry that holds the
    # advanced position.
    expected_slot_0 = (
        block_table_tensor[0, 51 // block_size].item() * block_size + 51 % block_size
    )
    expected_slot_1 = (
        block_table_tensor[1, 99 // block_size].item() * block_size + 99 % block_size
    )
    assert out_slot[0].item() == expected_slot_0
    assert out_slot[1].item() == expected_slot_1
    # Overflowing rows: slot is masked out with the padding id.
    assert out_slot[2].item() == PADDING_SLOT_ID
    assert out_slot[3].item() == PADDING_SLOT_ID

    # seq_lens is updated in place: +1 for in-range rows (row 1 lands exactly
    # on max_model_len), 1 for overflowing rows.
    assert seq_lens[0].item() == 52
    assert seq_lens[1].item() == 100
    assert seq_lens[2].item() == 1
    assert seq_lens[3].item() == 1
def test_eagle_step_slot_mapping_kernel_cudagraph_padding():
    """Check the cudagraph-padding path (input_batch_size > batch_size).

    Real rows must match the Python reference, and every padding entry of
    the slot-mapping buffer must be overwritten with PADDING_SLOT_ID.
    """
    device = torch.device("cuda")
    batch_size, input_batch_size = 4, 8
    block_size, max_model_len = 16, 4096
    n_blocks_per_req = (max_model_len + block_size - 1) // block_size

    positions_1d = torch.tensor([10, 20, 30, 40], dtype=torch.int64, device=device)
    block_table_tensor = torch.randint(
        0, 100, (batch_size, n_blocks_per_req), dtype=torch.int32, device=device
    )
    seq_lens = torch.tensor([11, 21, 31, 41], dtype=torch.int32, device=device)

    expected = _reference_eagle_step_slot_mapping(
        positions_1d.clone(),
        block_table_tensor,
        seq_lens.clone(),
        block_size,
        max_model_len,
    )
    ref_clamped, ref_slot, ref_seq_lens = expected

    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
    # Pre-fill with a sentinel so untouched padding entries would be detected.
    out_slot = torch.full(
        (input_batch_size,), -999, dtype=torch.int64, device=device
    )
    seq_lens_copy = seq_lens.clone()

    eagle_step_update_slot_mapping_and_metadata(
        positions_1d=positions_1d,
        block_table_tensor=block_table_tensor,
        seq_lens=seq_lens_copy,
        block_size=block_size,
        max_model_len=max_model_len,
        out_clamped_positions=out_clamped,
        out_slot_mapping=out_slot,
        input_batch_size=input_batch_size,
    )

    # Real slots should match the reference
    real_slots = out_slot[:batch_size]
    assert torch.equal(out_clamped, ref_clamped)
    assert torch.equal(real_slots, ref_slot)
    assert torch.equal(seq_lens_copy, ref_seq_lens)

    # Padding slots should be PADDING_SLOT_ID
    padding_slots = out_slot[batch_size:input_batch_size]
    assert bool(torch.all(padding_slots == PADDING_SLOT_ID))
vllm/v1/spec_decode/eagle.py
View file @
4aaaf8c8
...
@@ -44,6 +44,7 @@ from vllm.v1.spec_decode.utils import (
...
@@ -44,6 +44,7 @@ from vllm.v1.spec_decode.utils import (
copy_and_expand_eagle_inputs_kernel
,
copy_and_expand_eagle_inputs_kernel
,
eagle_prepare_inputs_padded_kernel
,
eagle_prepare_inputs_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
eagle_step_update_slot_mapping_and_metadata
,
extend_all_queries_by_N
,
extend_all_queries_by_N
,
)
)
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
...
@@ -533,41 +534,46 @@ class SpecDecodeBaseProposer:
...
@@ -533,41 +534,46 @@ class SpecDecodeBaseProposer:
common_attn_metadata
.
_seq_lens_cpu
=
None
common_attn_metadata
.
_seq_lens_cpu
=
None
common_attn_metadata
.
_num_computed_tokens_cpu
=
None
common_attn_metadata
.
_num_computed_tokens_cpu
=
None
block_size
=
self
.
block_size
assert
block_size
>
0
,
"block_size has not been initialized."
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
# tensor.argmax() returns int64 by default.
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
# Use fused kernel for slot mapping and metadata updates.
# Write clamped positions directly into the positions buffer to
# avoid an extra D2D copy for the common (non-mrope) case.
positions_1d
=
positions
[
0
]
if
self
.
uses_mrope
else
positions
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
positions
+=
1
out_pos
=
self
.
mrope_positions
[
0
,
:
batch_size
]
# NOTE(woosuk): We should handle the case where the draft model
elif
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
>
0
:
# generates tokens beyond the max model length.
out_pos
=
self
.
xdrope_positions
[
0
,
:
batch_size
]
# Since it is complex to remove such requests from the batch,
else
:
# we keep them in the batch but adjust the position ids
out_pos
=
self
.
positions
[:
batch_size
]
# and slot mappings to avoid the
eagle_step_update_slot_mapping_and_metadata
(
# out-of-range access during the model execution.
positions_1d
=
positions_1d
,
# The draft tokens generated with this adjustment
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
# should be ignored.
seq_lens
=
common_attn_metadata
.
seq_lens
,
exceeds_max_model_len
=
positions
[
0
]
>=
self
.
max_model_len
block_size
=
block_size
,
# Mask out the position ids that exceed the max model length.
max_model_len
=
self
.
max_model_len
,
# Otherwise, we may get out-of-range error in RoPE.
out_clamped_positions
=
out_pos
,
clamped_positions
=
torch
.
where
(
out_slot_mapping
=
self
.
_slot_mapping_buffer
[:
input_batch_size
],
exceeds_max_model_len
.
unsqueeze
(
0
),
input_batch_size
=
input_batch_size
,
torch
.
zeros_like
(
positions
),
positions
,
)
)
common_attn_metadata
.
slot_mapping
=
self
.
_slot_mapping_buffer
[:
batch_size
]
if
self
.
uses_mrope
:
self
.
mrope_positions
[
1
:,
:
batch_size
]
=
self
.
mrope_positions
[
0
,
:
batch_size
]
positions
=
self
.
mrope_positions
[:,
:
batch_size
]
elif
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
>
0
:
self
.
xdrope_positions
[
1
:,
:
batch_size
]
=
self
.
xdrope_positions
[
0
,
:
batch_size
]
positions
=
self
.
xdrope_positions
[
0
,
:
batch_size
]
else
:
else
:
positions
+=
1
positions
=
self
.
positions
[:
batch_size
]
exceeds_max_model_len
=
positions
>=
self
.
max_model_len
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
# For data integrity when async scheduling, we shouldn't use in place
# operations in case they are modified in next step's `prepare_input`
# of main model.
# Increment the sequence lengths.
common_attn_metadata
.
seq_lens
+=
1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
# Increment the maximum sequence length. We increment max_seq_len
# Increment the maximum sequence length. We increment max_seq_len
# unconditionally even though some seq_lens may have been capped above,
# unconditionally even though some seq_lens may have been capped above,
# as max_seq_len serves as an upper bound for sequence lengths.
# as max_seq_len serves as an upper bound for sequence lengths.
...
@@ -582,33 +588,6 @@ class SpecDecodeBaseProposer:
...
@@ -582,33 +588,6 @@ class SpecDecodeBaseProposer:
if
common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
if
common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
# Compute the slot mapping.
block_size
=
self
.
block_size
assert
block_size
>
0
,
"block_size has not been initialized."
if
self
.
uses_mrope
:
# all dimensions of positions are the same
block_numbers
=
clamped_positions
[
0
]
//
block_size
else
:
block_numbers
=
clamped_positions
//
block_size
block_ids
=
common_attn_metadata
.
block_table_tensor
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
)
)
block_ids
=
block_ids
.
view
(
-
1
)
if
self
.
uses_mrope
:
common_attn_metadata
.
slot_mapping
=
(
block_ids
*
block_size
+
clamped_positions
[
0
]
%
block_size
)
else
:
common_attn_metadata
.
slot_mapping
=
(
block_ids
*
block_size
+
clamped_positions
%
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Rebuild attention metadata
# Rebuild attention metadata
for
attn_group
in
self
.
draft_attn_groups
:
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
...
@@ -620,7 +599,6 @@ class SpecDecodeBaseProposer:
...
@@ -620,7 +599,6 @@ class SpecDecodeBaseProposer:
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
_set_positions
(
batch_size
,
clamped_positions
)
self
.
hidden_states
[:
batch_size
]
=
hidden_states
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
self
.
inputs_embeds
[:
batch_size
]
=
self
.
model
.
embed_input_ids
(
input_ids
)
self
.
inputs_embeds
[:
batch_size
]
=
self
.
model
.
embed_input_ids
(
input_ids
)
...
@@ -646,9 +624,7 @@ class SpecDecodeBaseProposer:
...
@@ -646,9 +624,7 @@ class SpecDecodeBaseProposer:
num_tokens
=
input_batch_size
,
num_tokens
=
input_batch_size
,
num_tokens_across_dp
=
batch_size_across_dp
,
num_tokens_across_dp
=
batch_size_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
slot_mapping
=
self
.
_get_slot_mapping
(
input_batch_size
),
input_batch_size
,
common_attn_metadata
.
slot_mapping
),
):
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
if
not
self
.
model_returns_tuple
():
...
...
vllm/v1/spec_decode/utils.py
View file @
4aaaf8c8
...
@@ -11,6 +11,114 @@ from vllm.v1.attention.backends.utils import (
...
@@ -11,6 +11,114 @@ from vllm.v1.attention.backends.utils import (
PADDING_SLOT_ID
=
-
1
PADDING_SLOT_ID
=
-
1
@triton.jit
def eagle_step_slot_mapping_metadata_kernel(
    positions_ptr,  # [batch_size] - current positions (1D view for M-RoPE)
    block_table_ptr,  # [batch_size, n_blocks_per_req]
    block_table_stride,  # row stride of block_table (elements between rows)
    seq_lens_ptr,  # [batch_size] - read and write
    out_clamped_positions_ptr,  # [batch_size] (output)
    out_slot_mapping_ptr,  # [input_batch_size] (output)
    block_size: tl.constexpr,
    max_model_len: tl.constexpr,
    n_blocks_per_req: tl.constexpr,
    PAD_ID: tl.constexpr,
    batch_size,
):
    """
    Fused kernel for EAGLE autoregressive step: updates positions, slot mapping,
    and sequence lengths in a single kernel to reduce launch overhead.

    Each program instance handles one index of a 1D grid (tl.program_id(0)).
    Instances with req_idx >= batch_size are cudagraph padding slots: they
    only write PAD_ID to out_slot_mapping and return.

    Each remaining instance handles one request and computes:
    - new_position = position + 1; if new_position >= max_model_len the
      stored position is 0, otherwise new_position
    - slot_mapping = block_id * block_size + position % block_size, where
      block_id is read from the request's block table row; PAD_ID if the
      position exceeded the limit
    - seq_lens (in place) = 1 if the position exceeded the limit, otherwise
      min(seq_len + 1, max_model_len)

    The block table is addressed as
    req_idx * block_table_stride + block_number, i.e. the column index is
    added unscaled, so the table's last dimension is treated as unit-stride.
    """
    req_idx = tl.program_id(0)

    # Padding instance: mark the slot as padding and do nothing else.
    if req_idx >= batch_size:
        tl.store(out_slot_mapping_ptr + req_idx, PAD_ID)
        return

    # Load current position and increment
    position = tl.load(positions_ptr + req_idx)
    new_position = position + 1

    # Check bounds and compute clamped position (reset to 0 on overflow)
    exceeds_max = new_position >= max_model_len
    clamped_position = tl.where(exceeds_max, 0, new_position)

    # Block table lookup: block_number = position // block_size
    # Limit block_number to the last column so the load below stays inside
    # the request's row.
    block_number = clamped_position // block_size
    block_number = tl.minimum(block_number, n_blocks_per_req - 1)
    block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number)
    slot_id = block_id * block_size + (clamped_position % block_size)
    # Overflowing requests get the padding id instead of a real slot.
    slot_id = tl.where(exceeds_max, PAD_ID, slot_id)

    # Update seq_lens: +1 normally (capped at max_model_len), or 1 if exceeded
    seq_len = tl.load(seq_lens_ptr + req_idx)
    new_seq_len = tl.where(exceeds_max, 1, seq_len + 1)
    new_seq_len = tl.minimum(new_seq_len, max_model_len)

    # Store outputs
    tl.store(out_clamped_positions_ptr + req_idx, clamped_position)
    tl.store(out_slot_mapping_ptr + req_idx, slot_id)
    tl.store(seq_lens_ptr + req_idx, new_seq_len)
def
eagle_step_update_slot_mapping_and_metadata
(
positions_1d
:
torch
.
Tensor
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_model_len
:
int
,
out_clamped_positions
:
torch
.
Tensor
,
out_slot_mapping
:
torch
.
Tensor
,
input_batch_size
:
int
|
None
=
None
,
)
->
None
:
"""
Fused update of slot mapping and metadata for one EAGLE autoregressive step.
Updates seq_lens in place. Writes to out_clamped_positions and out_slot_mapping.
When input_batch_size > batch_size, threads beyond batch_size write
PADDING_SLOT_ID to out_slot_mapping for cudagraph padding.
Args:
positions_1d: [batch_size] current positions (use positions[0] for M-RoPE)
block_table_tensor: [batch_size, n_blocks_per_req]
seq_lens: [batch_size] updated in place
block_size: KV cache block size
max_model_len: max model length for clamping
out_clamped_positions: [batch_size] output buffer for clamped positions
out_slot_mapping: [input_batch_size] output buffer for slot mapping
input_batch_size: total batch size including cudagraph padding;
defaults to batch_size (no padding)
"""
batch_size
=
positions_1d
.
shape
[
0
]
if
input_batch_size
is
None
:
input_batch_size
=
batch_size
n_blocks_per_req
=
block_table_tensor
.
shape
[
1
]
eagle_step_slot_mapping_metadata_kernel
[(
input_batch_size
,)](
positions_1d
,
block_table_tensor
,
block_table_tensor
.
stride
(
0
),
seq_lens
,
out_clamped_positions
,
out_slot_mapping
,
block_size
=
block_size
,
max_model_len
=
max_model_len
,
n_blocks_per_req
=
n_blocks_per_req
,
PAD_ID
=
PADDING_SLOT_ID
,
batch_size
=
batch_size
,
)
@
triton
.
jit
@
triton
.
jit
def
eagle_prepare_inputs_padded_kernel
(
def
eagle_prepare_inputs_padded_kernel
(
cu_num_draft_tokens_ptr
,
# [num_reqs]
cu_num_draft_tokens_ptr
,
# [num_reqs]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment