Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ccf90ba7
Unverified
Commit
ccf90ba7
authored
Apr 13, 2026
by
Giancarlo Delfin
Committed by
GitHub
Apr 13, 2026
Browse files
[Model Runner V2] Add full cuda graph support for eagle prefill (#37588)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@inferact.ai
>
parent
6adacfcb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
210 additions
and
134 deletions
+210
-134
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+0
-6
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
+6
-16
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+204
-112
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
ccf90ba7
...
...
@@ -464,7 +464,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mappings_by_layer
=
self
.
execute_model_state
.
slot_mappings_by_layer
hidden_states
=
self
.
execute_model_state
.
hidden_states
aux_hidden_states
=
self
.
execute_model_state
.
aux_hidden_states
num_tokens_across_dp
=
self
.
execute_model_state
.
num_tokens_across_dp
self
.
execute_model_state
=
None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
...
...
@@ -496,7 +495,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens
=
self
.
req_states
.
next_prefill_tokens
,
temperature
=
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
seeds
=
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
num_tokens_across_dp
=
num_tokens_across_dp
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
,
mm_inputs
=
mm_inputs
,
...
...
@@ -1110,7 +1108,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
=
hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
kv_connector_output
=
kv_connector_output
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
if
not
self
.
is_last_pp_rank
:
...
...
@@ -1135,7 +1132,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
=
self
.
execute_model_state
.
hidden_states
aux_hidden_states
=
self
.
execute_model_state
.
aux_hidden_states
kv_connector_output
=
self
.
execute_model_state
.
kv_connector_output
num_tokens_across_dp
=
self
.
execute_model_state
.
num_tokens_across_dp
self
.
execute_model_state
=
None
if
not
self
.
is_last_pp_rank
:
...
...
@@ -1228,7 +1224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
req_states
.
next_prefill_tokens
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
num_tokens_across_dp
=
num_tokens_across_dp
,
mm_inputs
=
mm_inputs
,
)
self
.
req_states
.
draft_tokens
[
input_batch
.
idx_mapping
]
=
draft_tokens
...
...
@@ -1336,4 +1331,3 @@ class ExecuteModelState(NamedTuple):
hidden_states
:
torch
.
Tensor
|
None
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
kv_connector_output
:
KVConnectorOutput
|
None
num_tokens_across_dp
:
torch
.
Tensor
|
None
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
View file @
ccf90ba7
...
...
@@ -19,21 +19,16 @@ from vllm.v1.worker.utils import AttentionGroup
class
EagleCudaGraphManager
(
CudaGraphManager
):
"""CudaGraphManager for Eagle speculative decoding
(FULL mode only)
."""
"""CudaGraphManager for Eagle speculative decoding."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
cudagraph_mode
:
CUDAGraphMode
,
d
raft_tok
en
s
:
torch
.
Tensor
,
d
ecode_query_l
en
:
int
,
):
assert
not
cudagraph_mode
.
has_mode
(
CUDAGraphMode
.
PIECEWISE
),
(
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
# Eagle always uses uniform decode with query_len=1
super
().
__init__
(
vllm_config
,
device
,
cudagraph_mode
,
decode_query_len
=
1
)
self
.
draft_tokens
=
draft_tokens
super
().
__init__
(
vllm_config
,
device
,
cudagraph_mode
,
decode_query_len
)
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
...
...
@@ -44,7 +39,7 @@ class EagleCudaGraphManager(CudaGraphManager):
def
capture
(
self
,
generate
_fn
:
Callable
,
forward
_fn
:
Callable
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
...
...
@@ -52,7 +47,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config
:
KVCacheConfig
,
progress_bar_desc
:
str
=
"Capturing CUDA graphs"
,
)
->
None
:
"""Capture CUDA graphs for Eagle
speculative decoding (FULL mode only)
."""
"""Capture CUDA graphs for Eagle."""
def
create_forward_fn
(
desc
:
BatchExecutionDescriptor
,
...
...
@@ -74,7 +69,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config
,
)
return
lambda
cg_mode
:
generate
_fn
(
return
lambda
cg_mode
:
forward
_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
...
...
@@ -84,8 +79,3 @@ class EagleCudaGraphManager(CudaGraphManager):
)
super
().
capture
(
create_forward_fn
,
progress_bar_desc
)
def
run_fullgraph
(
self
,
desc
:
BatchExecutionDescriptor
)
->
torch
.
Tensor
:
"""Replay a captured FULL cudagraph and return draft tokens."""
super
().
run_fullgraph
(
desc
)
return
self
.
draft_tokens
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
ccf90ba7
...
...
@@ -19,11 +19,16 @@ from vllm.v1.worker.gpu.attn_utils import (
init_attn_backend
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
get_uniform_token_count
,
)
from
vllm.v1.worker.gpu.dp_utils
import
dispatch_cg_and_sync_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.spec_decode.eagle.cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle.cudagraph
import
(
EagleCudaGraphManager
,
)
from
vllm.v1.worker.gpu.spec_decode.eagle.utils
import
load_eagle_model
logger
=
init_logger
(
__name__
)
...
...
@@ -76,6 +81,9 @@ class EagleSpeculator:
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
last_token_indices
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
supports_mm_inputs
=
MULTIMODAL_REGISTRY
.
supports_multimodal_inputs
(
self
.
draft_model_config
...
...
@@ -95,20 +103,30 @@ class EagleSpeculator:
device
=
device
,
)
self
.
cudagraph_manager
:
EagleCudaGraphManager
|
None
=
None
self
.
prefill_cudagraph_manager
:
EagleCudaGraphManager
|
None
=
None
self
.
decode_cudagraph_manager
:
EagleCudaGraphManager
|
None
=
None
def
init_cudagraph_manager
(
self
,
cudagraph_mode
:
CUDAGraphMode
)
->
None
:
if
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
:
cudagraph_mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
else
:
cudagraph_mode
=
CUDAGraphMode
.
NONE
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
cudagraph_mode
=
self
.
vllm_config
.
compilation_config
.
cudagraph_mode
# Initialize cudagraph manager for draft prefill (draft position 0).
self
.
prefill_cudagraph_manager
=
EagleCudaGraphManager
(
self
.
vllm_config
,
self
.
device
,
cudagraph_mode
,
self
.
draft_tokens
,
self
.
num_speculative_steps
+
1
,
)
# Initialize cudagraph manager for draft generation (draft positions > 0).
self
.
decode_cudagraph_manager
=
EagleCudaGraphManager
(
self
.
vllm_config
,
self
.
device
,
# Only use FULL graph mode, if available, because draft decodes
# only consist of a single token.
cudagraph_mode
.
decode_mode
(),
decode_query_len
=
1
,
)
# Share a single pool between prefill and decode since they never
# execute concurrently.
self
.
decode_cudagraph_manager
.
pool
=
self
.
prefill_cudagraph_manager
.
pool
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
target_attn_layer_names
=
get_layers_from_vllm_config
(
...
...
@@ -189,6 +207,47 @@ class EagleSpeculator:
last_hidden_states
,
hidden_states
=
ret_hidden_states
return
last_hidden_states
,
hidden_states
def prefill(
    self,
    num_reqs: int,
    num_tokens: int,
    attn_metadata: dict[str, Any] | None,
    slot_mappings: dict[str, torch.Tensor] | None,
    num_tokens_across_dp: torch.Tensor | None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> None:
    """Run the draft model once and sample the first draft token per request.

    The routine is: model forward (``self.run_model``), ``compute_logits``
    on the hidden state at each request's last token, then
    ``gumbel_sample``. ``capture_model`` hands this method to the prefill
    CUDA graph manager, so every result is written into pre-allocated
    buffers on ``self`` and nothing is returned.

    Buffers read:
        ``self.last_token_indices[:num_reqs]``: index of each request's
            last token in the flattened token batch.
        ``self.input_buffers.positions``: per-token positions.
        ``self.idx_mapping``, ``self.temperature``, ``self.seeds``:
            sampling inputs forwarded to ``gumbel_sample``.

    Buffers written:
        ``self.draft_tokens[:num_reqs, 0]``: the sampled draft tokens.
        ``self.hidden_states[:num_reqs]``: hidden state at each request's
            last token.
        ``self.input_buffers.positions[:num_reqs]``: position of each
            request's last token.

    Args:
        num_reqs: Number of requests; bounds every per-request slice.
        num_tokens: Token count forwarded to ``self.run_model``.
        attn_metadata: Forwarded unchanged to ``self.run_model``.
        slot_mappings: Forwarded unchanged to ``self.run_model``.
        num_tokens_across_dp: Forwarded unchanged to ``self.run_model``.
        cudagraph_runtime_mode: Forwarded unchanged to ``self.run_model``.
        mm_inputs: Forwarded unchanged to ``self.run_model``.
    """
    last_token_indices = self.last_token_indices[:num_reqs]
    # Indexing with a tensor yields a copy, so `pos` keeps the gathered
    # values even though `input_buffers.positions` is overwritten below.
    pos = self.input_buffers.positions[last_token_indices]
    idx_mapping = self.idx_mapping[:num_reqs]
    last_hidden_states, hidden_states = self.run_model(
        num_tokens,
        attn_metadata,
        slot_mappings,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        mm_inputs=mm_inputs,
    )
    # Only the hidden state at each request's last token is used for
    # sampling.
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states)

    # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
    # used for draft and target sampling.
    # NOTE(review): `processed_logits_out` presumably receives the processed
    # logits for draft position 0 when `self.draft_logits` is allocated;
    # confirm against `gumbel_sample`.
    self.draft_tokens[:num_reqs, 0] = gumbel_sample(
        logits,
        idx_mapping,
        self.temperature,
        self.seeds,
        pos + 1,
        apply_temperature=True,
        processed_logits_out=self.draft_logits[:, 0]
        if self.draft_logits is not None
        else None,
    )
    # Compact the per-request results into the first `num_reqs` rows of the
    # persistent buffers.
    self.hidden_states[:num_reqs] = hidden_states[last_token_indices]
    self.input_buffers.positions[:num_reqs] = pos
def
generate_draft
(
self
,
num_reqs
:
int
,
...
...
@@ -281,19 +340,46 @@ class EagleSpeculator:
return
attn_metadata
def
capture_model
(
self
)
->
None
:
assert
self
.
cudagraph_manager
is
not
None
logger
.
info
(
"Capturing model for Eagle speculator..."
)
# Reset indices to zero so that stale values from prior dummy runs
# cannot cause out-of-bounds indexing during capture.
self
.
last_token_indices
.
zero_
()
# Capture the prefill routine (model forward + compute_logits +
# gumbel_sample).
# For FULL graphs, the entire routine is recorded as one graph.
# For PIECEWISE, only the model's compiled regions are captured
# and the rest (compute_logits, gumbel_sample) runs eagerly.
assert
self
.
prefill_cudagraph_manager
is
not
None
self
.
prefill_cudagraph_manager
.
capture
(
self
.
prefill
,
self
.
model_state
,
self
.
input_buffers
,
self
.
block_tables
,
self
.
attn_groups
,
self
.
kv_cache_config
,
progress_bar_desc
=
"Capturing eagle prefill CUDA graphs"
,
)
if
self
.
num_speculative_steps
==
1
:
return
logger
.
info
(
"Capturing model for Eagle speculator..."
)
self
.
cudagraph_manager
.
capture
(
# Capture the decode draft generation loop (model forward +
# compute_logits + gumbel_sample + update_eagle_inputs, for
# each step).
# For FULL graphs, the entire multi-step loop is recorded as
# one graph. For PIECEWISE, only the model's compiled regions
# are captured, and the rest (compute_logits, gumbel_sample,
# update_eagle_inputs) runs eagerly.
assert
self
.
decode_cudagraph_manager
is
not
None
self
.
decode_cudagraph_manager
.
capture
(
self
.
generate_draft
,
self
.
model_state
,
self
.
input_buffers
,
self
.
block_tables
,
self
.
attn_groups
,
self
.
kv_cache_config
,
progress_bar_desc
=
"Capturing eagle CUDA graphs"
,
progress_bar_desc
=
"Capturing eagle
decode
CUDA graphs"
,
)
@
torch
.
inference_mode
()
...
...
@@ -324,6 +410,10 @@ class EagleSpeculator:
mm_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
is_profile
:
bool
=
False
,
)
->
torch
.
Tensor
:
num_tokens
=
input_batch
.
num_tokens_after_padding
num_reqs
=
input_batch
.
num_reqs
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
()
# NOTE(woosuk): The CPU does not know the number of rejected tokens, so to
# avoid a CPU-GPU synchronization we keep the size of eagle's input_ids and
# hidden_states the same as the target model's. This means we pad each
...
...
@@ -337,82 +427,88 @@ class EagleSpeculator:
)
else
:
hidden_states
=
last_hidden_states
num_tokens
=
input_batch
.
num_tokens_after_padding
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
self
.
hidden_states
[:
num_tokens
].
copy_
(
hidden_states
)
# Copy temperature, seeds, and idx mapping to the pre-allocated buffers.
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
self
.
temperature
.
copy_
(
temperature
)
self
.
seeds
.
copy_
(
seeds
)
self
.
idx_mapping
[:
num_reqs
].
copy_
(
input_batch
.
idx_mapping
)
# Get the input ids and last token indices for the speculator.
last_token_indices
=
prepare_eagle_inputs
(
prepare_eagle_inputs
(
self
.
input_buffers
,
input_batch
,
self
.
last_token_indices
,
num_sampled
,
num_rejected
,
last_sampled
,
next_prefill_tokens
,
self
.
max_num_reqs
,
)
# Prefill: Run the eagle speculator with eager mode.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
# When all requests are decoding (no true prefills), each has
# num_speculative_steps + 1 tokens, enabling FULL graph replay.
# Mixed or prefill-only batches fall back to PIECEWISE.
prefill_batch_desc
,
num_tokens_across_dp
=
dispatch_cg_and_sync_dp
(
self
.
prefill_cudagraph_manager
,
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
=
num_tokens_across_dp
,
mm_inputs
=
mm_inputs
,
get_uniform_token_count
(
num_reqs
,
num_tokens
,
max_query_len
)
,
dp_size
=
self
.
dp_size
,
dp_rank
=
self
.
dp_rank
,
need_eager
=
is_profile
,
)
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
self
.
temperature
.
copy_
(
temperature
)
self
.
seeds
.
copy_
(
seeds
)
# Gather the values and copy them to the pre-allocated buffers.
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
idx_mapping
,
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
apply_temperature
=
True
,
processed_logits_out
=
self
.
draft_logits
[:,
0
]
if
self
.
draft_logits
is
not
None
else
None
,
)
if
prefill_batch_desc
.
cg_mode
==
CUDAGraphMode
.
FULL
:
# It is necessary to rebuild the attention metadata when
# replaying the FULL graph so that any attention metadata
# builder state is updated.
self
.
_build_draft_attn_metadata
(
num_reqs
=
num_reqs
,
num_reqs_padded
=
prefill_batch_desc
.
num_reqs
or
num_reqs
,
num_tokens_padded
=
prefill_batch_desc
.
num_tokens
,
max_query_len
=
self
.
num_speculative_steps
+
1
,
)
# Replay the full graph for draft prefill.
assert
self
.
prefill_cudagraph_manager
is
not
None
self
.
prefill_cudagraph_manager
.
run_fullgraph
(
prefill_batch_desc
)
else
:
# The target model's attention metadata and slot mappings
# can be used directly for draft prefill, because the batch
# shape and KV cache layout are identical.
self
.
prefill
(
num_reqs
,
prefill_batch_desc
.
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
prefill_batch_desc
.
cg_mode
,
mm_inputs
=
mm_inputs
,
)
if
self
.
num_speculative_steps
==
1
:
# Early exit.
return
draft_tokens
.
view
(
-
1
,
1
)
return
self
.
draft_tokens
[:
num_reqs
,
:
1
]
# Save the draft tokens for the first step.
self
.
draft_tokens
[:
num_reqs
,
0
]
=
draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode
(
draft_tokens
,
hidden_states
,
last_token_indices
,
self
.
draft_tokens
[:
num_reqs
,
0
],
input_batch
.
seq_lens
,
num_rejected
,
self
.
input_buffers
,
self
.
hidden_states
,
self
.
max_model_len
,
self
.
max_num_reqs
,
)
# Each request produces exactly 1 token per draft
decode
step,
# enabling FULL
cuda
graph.
# Each request produces exactly 1 token per draft
generation
step,
# enabling FULL graph
replay
.
decode_batch_desc
,
num_tokens_across_dp
=
dispatch_cg_and_sync_dp
(
self
.
cudagraph_manager
,
self
.
decode_
cudagraph_manager
,
num_reqs
,
num_reqs
,
uniform_token_count
=
1
,
...
...
@@ -426,12 +522,12 @@ class EagleSpeculator:
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
# Build attention metadata and slot mappings for the draft
# decode steps. It is necessary to rebuild the attention
# metadata even when replaying the FULL
cuda
graph so that
#
any
attention metadata builder state is updated.
# metadata even when replaying the FULL graph so that
any
# attention metadata builder state is updated.
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
self
.
idx_mapping
[:
num_reqs
]
,
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
],
pos
,
self
.
input_buffers
.
positions
[:
num_reqs
]
,
decode_batch_desc
.
num_tokens
,
)
slot_mappings_updated
=
build_slot_mappings_by_layer
(
...
...
@@ -445,8 +541,9 @@ class EagleSpeculator:
)
if
decode_batch_desc
.
cg_mode
==
CUDAGraphMode
.
FULL
:
assert
self
.
cudagraph_manager
is
not
None
self
.
cudagraph_manager
.
run_fullgraph
(
decode_batch_desc
)
# Replay the full graph for draft generation.
assert
self
.
decode_cudagraph_manager
is
not
None
self
.
decode_cudagraph_manager
.
run_fullgraph
(
decode_batch_desc
)
else
:
self
.
generate_draft
(
num_reqs
,
...
...
@@ -464,6 +561,8 @@ def _prepare_eagle_inputs_kernel(
last_token_indices_ptr
,
eagle_input_ids_ptr
,
eagle_positions_ptr
,
eagle_query_start_loc_ptr
,
eagle_seq_lens_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
...
...
@@ -472,20 +571,24 @@ def _prepare_eagle_inputs_kernel(
num_sampled_ptr
,
num_rejected_ptr
,
query_start_loc_ptr
,
seq_lens_ptr
,
max_num_reqs
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
req_idx
=
tl
.
program_id
(
0
)
num_reqs
=
tl
.
num_programs
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch
_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch
_idx
+
1
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req
_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req
_idx
+
1
)
query_len
=
query_end
-
query_start
seq_len
=
tl
.
load
(
seq_lens_ptr
+
req_idx
)
# Get the true query length and next token after accounting for rejected tokens.
num_rejected
=
tl
.
load
(
num_rejected_ptr
+
batch
_idx
)
num_rejected
=
tl
.
load
(
num_rejected_ptr
+
req
_idx
)
query_len
-=
num_rejected
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch
_idx
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
req
_idx
)
if
num_sampled
>
0
:
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
else
:
...
...
@@ -501,7 +604,7 @@ def _prepare_eagle_inputs_kernel(
tl
.
store
(
eagle_input_ids_ptr
+
query_start
+
block
-
1
,
input_ids
,
mask
=
mask
)
last_token_index
=
query_start
+
query_len
-
1
tl
.
store
(
last_token_indices_ptr
+
batch
_idx
,
last_token_index
)
tl
.
store
(
last_token_indices_ptr
+
req
_idx
,
last_token_index
)
tl
.
store
(
eagle_input_ids_ptr
+
last_token_index
,
next_token
)
# Copy positions.
...
...
@@ -511,11 +614,29 @@ def _prepare_eagle_inputs_kernel(
target_pos
=
tl
.
load
(
target_positions_ptr
+
query_start
+
block
,
mask
=
mask
)
tl
.
store
(
eagle_positions_ptr
+
query_start
+
block
,
target_pos
,
mask
=
mask
)
# Copy query start locations.
tl
.
store
(
eagle_query_start_loc_ptr
+
req_idx
,
query_start
)
# Copy sequence lengths.
tl
.
store
(
eagle_seq_lens_ptr
+
req_idx
,
seq_len
)
if
req_idx
==
(
num_reqs
-
1
):
# Pad query_start_loc for CUDA graphs.
for
i
in
range
(
num_reqs
,
max_num_reqs
+
1
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
max_num_reqs
+
1
tl
.
store
(
eagle_query_start_loc_ptr
+
block
,
query_end
,
mask
=
mask
)
# Pad seq_lens for CUDA graphs.
for
i
in
range
(
num_reqs
,
max_num_reqs
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
max_num_reqs
tl
.
store
(
eagle_seq_lens_ptr
+
block
,
0
,
mask
=
mask
)
def
prepare_eagle_inputs
(
input_buffers
:
InputBuffers
,
input_batch
:
InputBatch
,
# [num_reqs]
last_token_indices
:
torch
.
Tensor
,
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
...
...
@@ -523,17 +644,15 @@ def prepare_eagle_inputs(
last_sampled
:
torch
.
Tensor
,
# [max_num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
max_num_reqs
,
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
last_token_indices
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
num_sampled
.
device
,
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_buffers
.
query_start_loc
,
input_buffers
.
seq_lens
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
...
...
@@ -542,6 +661,8 @@ def prepare_eagle_inputs(
num_sampled
,
num_rejected
,
input_batch
.
query_start_loc
,
input_batch
.
seq_lens
,
max_num_reqs
,
BLOCK_SIZE
=
1024
,
)
return
last_token_indices
...
...
@@ -550,18 +671,13 @@ def prepare_eagle_inputs(
@
triton
.
jit
def
_prepare_eagle_docode_kernel
(
draft_tokens_ptr
,
output_hidden_states_ptr
,
output_hidden_states_stride
,
last_token_indices_ptr
,
draft_tokens_stride
,
target_seq_lens_ptr
,
num_rejected_ptr
,
input_ids_ptr
,
positions_ptr
,
input_hidden_states_ptr
,
input_hidden_states_stride
,
query_start_loc_ptr
,
seq_lens_ptr
,
hidden_size
,
max_model_len
,
max_num_reqs
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
@@ -584,24 +700,9 @@ def _prepare_eagle_docode_kernel(
return
# draft token -> input id.
draft_token
=
tl
.
load
(
draft_tokens_ptr
+
req_idx
)
draft_token
=
tl
.
load
(
draft_tokens_ptr
+
req_idx
*
draft_tokens_stride
)
tl
.
store
(
input_ids_ptr
+
req_idx
,
draft_token
)
# output hidden states -> input hidden states.
src_idx
=
tl
.
load
(
last_token_indices_ptr
+
req_idx
)
for
i
in
range
(
0
,
hidden_size
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
hidden_size
output_hidden_states
=
tl
.
load
(
output_hidden_states_ptr
+
src_idx
*
output_hidden_states_stride
+
block
,
mask
=
mask
,
)
tl
.
store
(
input_hidden_states_ptr
+
req_idx
*
input_hidden_states_stride
+
block
,
output_hidden_states
,
mask
=
mask
,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
...
...
@@ -618,31 +719,22 @@ def _prepare_eagle_docode_kernel(
def
prepare_eagle_decode
(
draft_tokens
:
torch
.
Tensor
,
output_hidden_states
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
,
target_seq_lens
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
input_buffers
:
InputBuffers
,
input_hidden_states
:
torch
.
Tensor
,
max_model_len
:
int
,
max_num_reqs
:
int
,
):
num_reqs
=
draft_tokens
.
shape
[
0
]
hidden_size
=
output_hidden_states
.
shape
[
-
1
]
_prepare_eagle_docode_kernel
[(
num_reqs
+
1
,)](
draft_tokens
,
output_hidden_states
,
output_hidden_states
.
stride
(
0
),
last_token_indices
,
draft_tokens
.
stride
(
0
),
target_seq_lens
,
num_rejected
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
input_buffers
.
query_start_loc
,
input_buffers
.
seq_lens
,
hidden_size
,
max_model_len
,
max_num_reqs
,
BLOCK_SIZE
=
1024
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment