Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da3222f3
Unverified
Commit
da3222f3
authored
Nov 27, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 27, 2025
Browse files
[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
43c57925
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
526 additions
and
70 deletions
+526
-70
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+5
-4
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+19
-34
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+390
-32
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
+112
-0
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
da3222f3
...
...
@@ -233,10 +233,11 @@ def prepare_inputs_to_capture(
query_start_loc
.
np
[
num_reqs
:]
=
num_tokens
query_start_loc
.
copy_to_gpu
()
seq_lens_np
=
np
.
full
(
num_reqs
,
max_model_len
,
dtype
=
np
.
int32
)
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and
# seq_lens_np (CPU), which might cause issues in some attention backends.
input_buffers
.
seq_lens
[:
num_reqs
]
=
1
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len. This introduces a discrepancy between
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
# certain attention backends.
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input_block_tables
]
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
da3222f3
...
...
@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
,
)
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
def get_supported_tasks(self) -> tuple[str]:
    """Return the names of the tasks this runner supports."""
    # Generation is the only task reported by this runner.
    supported = ("generate",)
    return supported
...
...
@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
vllm_config
,
self
.
device
,
)
if
self
.
do_spec_decode
:
# HACK(woosuk)
self
.
speculator
.
set_attn
(
self
.
kv_cache_config
,
self
.
attn_metadata_builders
,
self
.
block_tables
,
)
# TODO(woosuk): Support other backends.
if
not
all
(
b
.
get_name
()
==
"FLASH_ATTN"
for
b
in
self
.
attn_backends
.
values
()):
raise
NotImplementedError
(
"Only FLASH_ATTN backend is supported currently."
)
...
...
@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
self
.
sampler
(
logits
,
sampling_metadata
)
@torch.inference_mode()
def _dummy_speculator_run(
    self,
    hidden_states: torch.Tensor,
    aux_hidden_states: list[torch.Tensor] | None,
) -> None:
    """Run one draft proposal on dummy inputs.

    The token count is taken from the leading dimension of
    ``hidden_states``; the request count is that value capped at
    ``self.max_num_reqs``.

    Args:
        hidden_states: Passed to ``propose_draft`` as ``last_hidden_states``.
        aux_hidden_states: Forwarded to ``propose_draft`` unchanged.
    """
    token_count = hidden_states.shape[0]
    req_count = min(token_count, self.max_num_reqs)

    dummy_batch = InputBatch.make_dummy(
        num_reqs=req_count,
        num_tokens=token_count,
        input_buffers=self.input_buffers,
        device=self.device,
    )
    dummy_sampling = SamplingMetadata.make_dummy(
        num_reqs=req_count,
        device=self.device,
    )

    # Per-request counters: one sampled token, zero rejected tokens.
    counter_opts = {"dtype": torch.int32, "device": self.device}
    sampled_counts = torch.ones(req_count, **counter_opts)
    rejected_counts = torch.zeros(req_count, **counter_opts)

    self.propose_draft(
        input_batch=dummy_batch,
        sampling_metadata=dummy_sampling,
        last_hidden_states=hidden_states,
        aux_hidden_states=aux_hidden_states,
        num_sampled=sampled_counts,
        num_rejected=rejected_counts,
    )
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
...
...
@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
self
.
_dummy_speculator_run
(
hidden_states
,
None
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
self
.
max_num_tokens
)
self
.
speculator
.
run_model
(
self
.
max_num_tokens
,
attn_metadata
=
None
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
torch
.
cuda
.
synchronize
()
del
hidden_states
,
sample_hidden_states
gc
.
collect
()
...
...
@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders
=
self
.
attn_metadata_builders
,
kv_cache_config
=
self
.
kv_cache_config
,
)
if
self
.
do_spec_decode
:
self
.
speculator
.
capture_model
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
da3222f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sampler
import
gumbel_sample
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
class
EagleSpeculator
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
@@ -27,13 +39,48 @@ class EagleSpeculator:
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
self
.
input_buffers
=
InputBuffers
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
hidden_size
=
self
.
hidden_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
device
,
pin_memory
=
self
.
pin_memory
,
)
self
.
hidden_states
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
hidden_size
,
dtype
=
self
.
dtype
,
device
=
device
,
)
self
.
temperature
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
,
)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
seeds
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
draft_tokens
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
num_speculative_steps
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
vllm_config
,
device
)
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
from
vllm.compilation.backends
import
set_model_tag
...
...
@@ -49,6 +96,91 @@ class EagleSpeculator:
del
self
.
model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
def set_attn(
    self,
    kv_cache_config: KVCacheConfig,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    block_tables: BlockTables,
) -> None:
    """Store the attention-related objects on this speculator.

    The three arguments are kept as attributes of the same name; nothing
    is copied or validated here.
    """
    self.block_tables = block_tables
    self.attn_metadata_builders = attn_metadata_builders
    self.kv_cache_config = kv_cache_config
@torch.inference_mode()
def run_model(
    self,
    num_tokens: int,
    attn_metadata: dict[str, Any] | None,
    num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the draft model eagerly on the first ``num_tokens`` buffer slots.

    Inputs are read from the pre-allocated buffers (``self.input_buffers``
    and ``self.hidden_states``), not passed in, so callers must fill those
    buffers before calling.

    Args:
        num_tokens: Number of leading buffer entries fed to the model.
        attn_metadata: Attention metadata handed to the forward context.
            ``None`` is accepted; the profiling call passes ``None``.
        num_tokens_across_dp: Forwarded to the forward context unchanged.

    Returns:
        ``(last_hidden_states, hidden_states)``. For the ``"mtp"`` method
        the model returns a single value, which is returned in both
        positions; otherwise the model's return value is unpacked as a pair.
    """
    buffers = self.input_buffers
    with set_forward_context(
        attn_metadata,
        self.vllm_config,
        num_tokens=num_tokens,
        # The forward context is always entered with CUDA graph mode NONE.
        cudagraph_runtime_mode=CUDAGraphMode.NONE,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        ret_hidden_states = self.model(
            input_ids=buffers.input_ids.gpu[:num_tokens],
            positions=buffers.positions[:num_tokens],
            hidden_states=self.hidden_states[:num_tokens],
        )

    if self.method == "mtp":
        last_hidden_states = ret_hidden_states
        hidden_states = ret_hidden_states
    else:
        last_hidden_states, hidden_states = ret_hidden_states
    return last_hidden_states, hidden_states
def
generate_draft
(
self
,
num_reqs
:
int
,
attn_metadata
:
dict
[
str
,
Any
],
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_reqs
,
attn_metadata
,
num_tokens_across_dp
)
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
self
.
temperature
[:
num_reqs
],
self
.
seeds
[:
num_reqs
],
pos
+
1
,
apply_temperature
=
True
,
)
self
.
draft_tokens
[:
num_reqs
,
step
]
=
draft_tokens
if
step
<
self
.
num_speculative_steps
-
1
:
# Update the inputs for the next step.
update_eagle_inputs
(
draft_tokens
,
hidden_states
,
self
.
input_buffers
,
self
.
hidden_states
,
self
.
max_model_len
,
)
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
def capture_model(self) -> None:
    """Capture CUDA graphs for the multi-step draft loop.

    Does nothing when only a single speculative step is configured.
    """
    if self.num_speculative_steps != 1:
        logger.info("Capturing model for Eagle speculator...")
        manager = self.cudagraph_manager
        manager.capture(
            self.generate_draft,
            self.input_buffers,
            self.block_tables,
            self.attn_metadata_builders,
            self.kv_cache_config,
        )
@
torch
.
inference_mode
()
def
propose
(
self
,
...
...
@@ -80,64 +212,110 @@ class EagleSpeculator:
)
else
:
hidden_states
=
last_hidden_states
num_tokens
=
input_batch
.
num_tokens_after_padding
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices
=
prepare_eagle_inputs
(
self
.
input_
id
s
,
self
.
input_
buffer
s
,
input_batch
,
num_sampled
,
num_rejected
,
last_sampled
,
next_prefill_tokens
,
)
input_ids
=
self
.
input_ids
[:
input_batch
.
num_tokens_after_padding
]
# Prefill: Run the eagle speculator with eager mode.
with
set_forward_context
(
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_tokens
,
input_batch
.
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
input_batch
.
positions
,
hidden_states
=
hidden_states
,
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
num_tokens_across_dp
=
None
,
# FIXME
)
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
temperature
=
sampling_metadata
.
temperature
[
cu_num_logits
]
seed
=
sampling_metadata
.
seeds
[
cu_num_logits
]
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
pos
=
input_batch
.
positions
[
last_token_indices
]
+
1
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
temperature
=
self
.
temperature
[:
num_reqs
]
seeds
=
self
.
seeds
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
# Gather the values and copy them to the pre-allocated buffers.
torch
.
gather
(
sampling_metadata
.
temperature
,
0
,
cu_num_logits
,
out
=
temperature
)
torch
.
gather
(
sampling_metadata
.
seeds
,
0
,
cu_num_logits
,
out
=
seeds
)
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
temperature
,
seed
,
pos
,
apply_temperature
=
True
logits
,
temperature
,
seed
s
,
pos
+
1
,
apply_temperature
=
True
)
if
self
.
num_speculative_steps
==
1
:
# Early exit.
return
draft_tokens
.
view
(
-
1
,
1
)
raise
NotImplementedError
(
"num_speculative_steps > 1 is not supported yet."
)
# Save the draft tokens for the first step.
self
.
draft_tokens
[:
num_reqs
,
0
]
=
draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode
(
draft_tokens
,
hidden_states
,
last_token_indices
,
input_batch
.
seq_lens
,
num_rejected
,
self
.
input_buffers
,
self
.
hidden_states
,
self
.
max_model_len
,
self
.
max_num_reqs
,
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc_gpu
,
pos
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
# Run CUDA graph.
self
.
cudagraph_manager
.
run
(
cudagraph_size
)
return
self
.
draft_tokens
[:
num_reqs
]
# Run eager mode.
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
# HACK(woosuk)
seq_lens_np
=
np
.
full
(
num_reqs
,
self
.
max_model_len
,
dtype
=
np
.
int32
)
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata
=
build_attn_metadata
(
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc_gpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
seq_lens_np
=
seq_lens_np
,
num_computed_tokens_cpu
=
None
,
# FIXME
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
)
self
.
generate_draft
(
num_reqs
,
attn_metadata
,
num_tokens_across_dp
=
None
)
# FIXME
return
self
.
draft_tokens
[:
num_reqs
]
@
triton
.
jit
def
_prepare_eagle_inputs_kernel
(
last_token_indices_ptr
,
eagle_input_ids_ptr
,
eagle_positions_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
...
...
@@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel(
tl
.
store
(
last_token_indices_ptr
+
batch_idx
,
last_token_index
)
tl
.
store
(
eagle_input_ids_ptr
+
last_token_index
,
next_token
)
# Copy positions.
for
i
in
range
(
0
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
target_pos
=
tl
.
load
(
target_positions_ptr
+
query_start
+
block
,
mask
=
mask
)
tl
.
store
(
eagle_positions_ptr
+
query_start
+
block
,
target_pos
,
mask
=
mask
)
def
prepare_eagle_inputs
(
eagle_input_ids
:
torch
.
Tensor
,
input_buffers
:
InputBuffers
,
input_batch
:
InputBatch
,
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
...
...
@@ -192,12 +377,14 @@ def prepare_eagle_inputs(
last_token_indices
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
eagle_input_ids
.
device
,
device
=
num_sampled
.
device
,
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
eagle_input_ids
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
positions
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
last_sampled
,
next_prefill_tokens
,
...
...
@@ -207,3 +394,174 @@ def prepare_eagle_inputs(
BLOCK_SIZE
=
1024
,
)
return
last_token_indices
@triton.jit
def _prepare_eagle_docode_kernel(
    draft_tokens_ptr,
    output_hidden_states_ptr,
    output_hidden_states_stride,
    last_token_indices_ptr,
    target_seq_lens_ptr,
    num_rejected_ptr,
    input_ids_ptr,
    positions_ptr,
    input_hidden_states_ptr,
    input_hidden_states_stride,
    query_start_loc_ptr,
    seq_lens_ptr,
    hidden_size,
    max_model_len,
    max_num_reqs,
    BLOCK_SIZE: tl.constexpr,
):
    # Launched with (num_reqs + 1) programs. Programs 0..num_reqs-1 each
    # prepare the inputs of one request; the extra last program only writes
    # query_start_loc and the zero padding of seq_lens.
    req_idx = tl.program_id(0)
    num_reqs = tl.num_programs(0) - 1

    if req_idx == num_reqs:
        # query_start_loc[i] = i for i < num_reqs; every later entry (up to
        # and including index max_num_reqs) repeats num_reqs so the buffer
        # is padded for CUDA graphs.
        for start in range(0, max_num_reqs + 1, BLOCK_SIZE):
            offs = start + tl.arange(0, BLOCK_SIZE)
            in_range = offs < max_num_reqs + 1
            qsl = tl.where(offs < num_reqs, offs, num_reqs)
            tl.store(query_start_loc_ptr + offs, qsl, mask=in_range)

        # Zero seq_lens from index num_reqs up to max_num_reqs (padding for
        # CUDA graphs).
        for start in range(req_idx, max_num_reqs, BLOCK_SIZE):
            offs = start + tl.arange(0, BLOCK_SIZE)
            in_range = offs < max_num_reqs
            tl.store(seq_lens_ptr + offs, 0, mask=in_range)
        return

    # The sampled draft token becomes this request's next input id.
    token = tl.load(draft_tokens_ptr + req_idx)
    tl.store(input_ids_ptr + req_idx, token)

    # Advance the position and derive the new sequence length.
    # NOTE(woosuk): To prevent out-of-range access, we clamp these values
    # if they reach the max model length.
    cur_pos = tl.load(positions_ptr + req_idx)
    next_pos = tl.minimum(cur_pos + 1, max_model_len - 1)
    tl.store(positions_ptr + req_idx, next_pos)

    target_len = tl.load(target_seq_lens_ptr + req_idx)
    rejected = tl.load(num_rejected_ptr + req_idx)
    accepted_len = target_len - rejected
    new_len = tl.minimum(accepted_len + 1, max_model_len)
    tl.store(seq_lens_ptr + req_idx, new_len)

    # Copy the output hidden-state row selected by last_token_indices into
    # input row `req_idx`, BLOCK_SIZE elements at a time.
    src_row = tl.load(last_token_indices_ptr + req_idx)
    src_base = output_hidden_states_ptr + src_row * output_hidden_states_stride
    dst_base = input_hidden_states_ptr + req_idx * input_hidden_states_stride
    for start in range(0, hidden_size, BLOCK_SIZE):
        offs = start + tl.arange(0, BLOCK_SIZE)
        in_range = offs < hidden_size
        row_chunk = tl.load(src_base + offs, mask=in_range)
        tl.store(dst_base + offs, row_chunk, mask=in_range)
def prepare_eagle_decode(
    draft_tokens: torch.Tensor,
    output_hidden_states: torch.Tensor,
    last_token_indices: torch.Tensor,
    target_seq_lens: torch.Tensor,
    num_rejected: torch.Tensor,
    input_buffers: InputBuffers,
    input_hidden_states: torch.Tensor,
    max_model_len: int,
    max_num_reqs: int,
):
    """Launch the kernel that sets up the inputs of the first decode step.

    The kernel writes into ``input_buffers`` (input ids, positions,
    query_start_loc, seq_lens) and into ``input_hidden_states`` in place;
    this function returns nothing.
    """
    num_reqs = draft_tokens.shape[0]
    hidden_size = output_hidden_states.shape[-1]
    src_row_stride = output_hidden_states.stride(0)
    dst_row_stride = input_hidden_states.stride(0)

    # One program per request, plus one extra program that the kernel uses
    # for query_start_loc and seq_lens padding.
    grid = (num_reqs + 1,)
    _prepare_eagle_docode_kernel[grid](
        draft_tokens,
        output_hidden_states,
        src_row_stride,
        last_token_indices,
        target_seq_lens,
        num_rejected,
        input_buffers.input_ids.gpu,
        input_buffers.positions,
        input_hidden_states,
        dst_row_stride,
        input_buffers.query_start_loc.gpu,
        input_buffers.seq_lens,
        hidden_size,
        max_model_len,
        max_num_reqs,
        BLOCK_SIZE=1024,
    )
@triton.jit
def _update_eagle_inputs_kernel(
    input_ids_ptr,
    positions_ptr,
    input_hidden_states_ptr,
    input_hidden_states_stride,
    seq_lens_ptr,
    max_model_len,
    draft_tokens_ptr,
    output_hidden_states_ptr,
    output_hidden_states_stride,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request: feed the outputs of the step that just ran
    # back in as the inputs of the next step.
    req_idx = tl.program_id(0)

    # The sampled draft token becomes the next input id.
    token = tl.load(draft_tokens_ptr + req_idx)
    tl.store(input_ids_ptr + req_idx, token)

    # Advance position and sequence length by one.
    # NOTE(woosuk): To prevent out-of-range access, we clamp these values
    # if they reach the max model length.
    cur_pos = tl.load(positions_ptr + req_idx)
    next_pos = tl.minimum(cur_pos + 1, max_model_len - 1)
    tl.store(positions_ptr + req_idx, next_pos)

    cur_len = tl.load(seq_lens_ptr + req_idx)
    next_len = tl.minimum(cur_len + 1, max_model_len)
    tl.store(seq_lens_ptr + req_idx, next_len)

    # Copy output hidden-state row `req_idx` into input row `req_idx`,
    # BLOCK_SIZE elements at a time.
    src_base = output_hidden_states_ptr + req_idx * output_hidden_states_stride
    dst_base = input_hidden_states_ptr + req_idx * input_hidden_states_stride
    for start in range(0, hidden_size, BLOCK_SIZE):
        offs = start + tl.arange(0, BLOCK_SIZE)
        in_range = offs < hidden_size
        row_chunk = tl.load(src_base + offs, mask=in_range)
        tl.store(dst_base + offs, row_chunk, mask=in_range)
def update_eagle_inputs(
    draft_tokens: torch.Tensor,
    output_hidden_states: torch.Tensor,
    input_buffers: InputBuffers,
    hidden_states: torch.Tensor,
    max_model_len: int,
):
    """Launch the kernel that turns one step's outputs into the next inputs.

    ``output_hidden_states`` must be 2-D; its shape supplies the request
    count and the hidden size. The kernel updates ``input_buffers`` (input
    ids, positions, seq_lens) and ``hidden_states`` in place; nothing is
    returned.
    """
    num_reqs, hidden_size = output_hidden_states.shape
    dst_row_stride = hidden_states.stride(0)
    src_row_stride = output_hidden_states.stride(0)

    # One program per request.
    grid = (num_reqs,)
    _update_eagle_inputs_kernel[grid](
        input_buffers.input_ids.gpu,
        input_buffers.positions,
        hidden_states,
        dst_row_stride,
        input_buffers.seq_lens,
        max_model_len,
        draft_tokens,
        output_hidden_states,
        src_row_stride,
        hidden_size,
        BLOCK_SIZE=1024,
    )
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
0 → 100644
View file @
da3222f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
capture_graphs
,
get_cudagraph_sizes
,
prepare_inputs_to_capture
,
)
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
class EagleCudaGraphManager:
    """Captures and replays CUDA graphs of a caller-supplied draft function.

    Captured graphs are stored in ``self.graphs`` keyed by the token count
    they were captured for, and all share one graph memory pool.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        scheduler_config = vllm_config.scheduler_config
        self.scheduler_config = scheduler_config
        self.device = device

        # Size limits read from the model and scheduler configs.
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        compilation_config = vllm_config.compilation_config
        self.compilation_config = compilation_config
        assert compilation_config is not None

        # Resolve the effective CUDA graph mode.
        mode = compilation_config.cudagraph_mode
        if mode is None:
            mode = CUDAGraphMode.NONE
        elif mode == CUDAGraphMode.FULL:
            # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
            mode = CUDAGraphMode.FULL_DECODE_ONLY
        self.cudagraph_mode = mode

        self.cudagraph_sizes = get_cudagraph_sizes(
            compilation_config.cudagraph_capture_sizes,
            self.max_num_reqs,
            self.max_num_tokens,
            mode,
        )

        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        self.pool = torch.cuda.graph_pool_handle()

    def get_cudagraph_size(self, num_tokens: int) -> int | None:
        """Look up ``num_tokens`` in the size table; ``None`` if absent."""
        size = self.cudagraph_sizes.get(num_tokens)
        return size

    def capture_graph(
        self,
        num_tokens: int,
        generate_fn: Callable,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Warm up and capture one CUDA graph for ``num_tokens`` tokens.

        ``generate_fn`` is called twice with the same arguments: once
        eagerly as a warm-up and once inside the graph capture. The captured
        graph is stored under ``num_tokens``, which must not already have
        one.
        """
        num_reqs = min(num_tokens, self.max_num_reqs)
        attn_metadata = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            input_buffers,
            block_tables,
            attn_metadata_builders,
            self.max_model_len,
            kv_cache_config,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
        call_args = (num_tokens, attn_metadata, num_tokens_across_dp)

        # Warm up.
        generate_fn(*call_args)

        # Capture the graph.
        assert num_tokens not in self.graphs
        captured = torch.cuda.CUDAGraph()
        with torch.cuda.graph(captured, self.pool):
            generate_fn(*call_args)
        self.graphs[num_tokens] = captured

    @torch.inference_mode()
    def capture(
        self,
        generate_fn: Callable,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Hand ``capture_graph`` and its arguments to ``capture_graphs``."""
        forwarded = {
            "generate_fn": generate_fn,
            "input_buffers": input_buffers,
            "block_tables": block_tables,
            "attn_metadata_builders": attn_metadata_builders,
            "kv_cache_config": kv_cache_config,
        }
        capture_graphs(
            self.cudagraph_sizes,
            self.device,
            self.capture_graph,
            **forwarded,
        )

    def run(self, num_tokens: int) -> None:
        """Replay the graph captured for ``num_tokens``."""
        captured_graphs = self.graphs
        assert num_tokens in captured_graphs
        captured_graphs[num_tokens].replay()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment