Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da3222f3
Unverified
Commit
da3222f3
authored
Nov 27, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 27, 2025
Browse files
[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
43c57925
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
526 additions
and
70 deletions
+526
-70
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+5
-4
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+19
-34
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+390
-32
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
+112
-0
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
da3222f3
...
@@ -233,10 +233,11 @@ def prepare_inputs_to_capture(
...
@@ -233,10 +233,11 @@ def prepare_inputs_to_capture(
query_start_loc
.
np
[
num_reqs
:]
=
num_tokens
query_start_loc
.
np
[
num_reqs
:]
=
num_tokens
query_start_loc
.
copy_to_gpu
()
query_start_loc
.
copy_to_gpu
()
seq_lens_np
=
np
.
full
(
num_reqs
,
max_model_len
,
dtype
=
np
.
int32
)
seq_lens_np
=
np
.
full
(
num_reqs
,
max_model_len
,
dtype
=
np
.
int32
)
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and
# rather than max_model_len. This introduces a discrepancy between
# seq_lens_np (CPU), which might cause issues in some attention backends.
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
input_buffers
.
seq_lens
[:
num_reqs
]
=
1
# certain attention backends.
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input_block_tables
]
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input_block_tables
]
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
da3222f3
...
@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# CUDA graphs.
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
,
)
def
get_supported_tasks
(
self
)
->
tuple
[
str
]:
def
get_supported_tasks
(
self
)
->
tuple
[
str
]:
return
(
"generate"
,)
return
(
"generate"
,)
...
@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
vllm_config
,
self
.
vllm_config
,
self
.
device
,
self
.
device
,
)
)
if
self
.
do_spec_decode
:
# HACK(woosuk)
self
.
speculator
.
set_attn
(
self
.
kv_cache_config
,
self
.
attn_metadata_builders
,
self
.
block_tables
,
)
# TODO(woosuk): Support other backends.
# TODO(woosuk): Support other backends.
if
not
all
(
b
.
get_name
()
==
"FLASH_ATTN"
for
b
in
self
.
attn_backends
.
values
()):
if
not
all
(
b
.
get_name
()
==
"FLASH_ATTN"
for
b
in
self
.
attn_backends
.
values
()):
raise
NotImplementedError
(
"Only FLASH_ATTN backend is supported currently."
)
raise
NotImplementedError
(
"Only FLASH_ATTN backend is supported currently."
)
...
@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
self
.
sampler
(
logits
,
sampling_metadata
)
self
.
sampler
(
logits
,
sampling_metadata
)
@
torch
.
inference_mode
()
def
_dummy_speculator_run
(
self
,
hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
)
->
None
:
num_tokens
=
hidden_states
.
shape
[
0
]
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_batch
=
InputBatch
.
make_dummy
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
input_buffers
=
self
.
input_buffers
,
device
=
self
.
device
,
)
sampling_metadata
=
SamplingMetadata
.
make_dummy
(
num_reqs
=
num_reqs
,
device
=
self
.
device
,
)
num_sampled
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
num_rejected
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
propose_draft
(
input_batch
=
input_batch
,
sampling_metadata
=
sampling_metadata
,
last_hidden_states
=
hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
num_sampled
=
num_sampled
,
num_rejected
=
num_rejected
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
...
@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
self
.
_dummy_sampler_run
(
sample_hidden_states
)
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
if
self
.
do_spec_decode
:
self
.
_dummy_speculator_run
(
hidden_states
,
None
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
self
.
max_num_tokens
)
self
.
speculator
.
run_model
(
self
.
max_num_tokens
,
attn_metadata
=
None
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
del
hidden_states
,
sample_hidden_states
del
hidden_states
,
sample_hidden_states
gc
.
collect
()
gc
.
collect
()
...
@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
kv_cache_config
=
self
.
kv_cache_config
,
kv_cache_config
=
self
.
kv_cache_config
,
)
)
if
self
.
do_spec_decode
:
self
.
speculator
.
capture_model
()
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
da3222f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sampler
import
gumbel_sample
from
vllm.v1.worker.gpu.sampler
import
gumbel_sample
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
class
EagleSpeculator
:
class
EagleSpeculator
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
@@ -27,14 +39,49 @@ class EagleSpeculator:
...
@@ -27,14 +39,49 @@ class EagleSpeculator:
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
input_ids
=
torch
.
zeros
(
self
.
input_buffers
=
InputBuffers
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
hidden_size
=
self
.
hidden_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
device
,
pin_memory
=
self
.
pin_memory
,
)
self
.
hidden_states
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
hidden_size
,
dtype
=
self
.
dtype
,
device
=
device
,
)
self
.
temperature
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
,
)
)
self
.
positions
=
torch
.
zeros
(
self
.
seeds
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
max_num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
draft_tokens
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
num_speculative_steps
,
dtype
=
torch
.
int64
,
device
=
device
,
)
)
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
vllm_config
,
device
)
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.backends
import
set_model_tag
...
@@ -49,6 +96,91 @@ class EagleSpeculator:
...
@@ -49,6 +96,91 @@ class EagleSpeculator:
del
self
.
model
.
lm_head
del
self
.
model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
def
set_attn
(
self
,
kv_cache_config
:
KVCacheConfig
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
block_tables
:
BlockTables
,
)
->
None
:
self
.
kv_cache_config
=
kv_cache_config
self
.
attn_metadata_builders
=
attn_metadata_builders
self
.
block_tables
=
block_tables
@
torch
.
inference_mode
()
def
run_model
(
self
,
num_tokens
:
int
,
attn_metadata
:
dict
[
str
,
Any
],
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
self
.
input_buffers
.
input_ids
.
gpu
[:
num_tokens
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
return
last_hidden_states
,
hidden_states
def
generate_draft
(
self
,
num_reqs
:
int
,
attn_metadata
:
dict
[
str
,
Any
],
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_reqs
,
attn_metadata
,
num_tokens_across_dp
)
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
self
.
temperature
[:
num_reqs
],
self
.
seeds
[:
num_reqs
],
pos
+
1
,
apply_temperature
=
True
,
)
self
.
draft_tokens
[:
num_reqs
,
step
]
=
draft_tokens
if
step
<
self
.
num_speculative_steps
-
1
:
# Update the inputs for the next step.
update_eagle_inputs
(
draft_tokens
,
hidden_states
,
self
.
input_buffers
,
self
.
hidden_states
,
self
.
max_model_len
,
)
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
def
capture_model
(
self
)
->
None
:
if
self
.
num_speculative_steps
==
1
:
return
logger
.
info
(
"Capturing model for Eagle speculator..."
)
self
.
cudagraph_manager
.
capture
(
self
.
generate_draft
,
self
.
input_buffers
,
self
.
block_tables
,
self
.
attn_metadata_builders
,
self
.
kv_cache_config
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
propose
(
def
propose
(
self
,
self
,
...
@@ -80,64 +212,110 @@ class EagleSpeculator:
...
@@ -80,64 +212,110 @@ class EagleSpeculator:
)
)
else
:
else
:
hidden_states
=
last_hidden_states
hidden_states
=
last_hidden_states
num_tokens
=
input_batch
.
num_tokens_after_padding
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
# Get the input ids and last token indices for the speculator.
# Get the input ids and last token indices for the speculator.
last_token_indices
=
prepare_eagle_inputs
(
last_token_indices
=
prepare_eagle_inputs
(
self
.
input_
id
s
,
self
.
input_
buffer
s
,
input_batch
,
input_batch
,
num_sampled
,
num_sampled
,
num_rejected
,
num_rejected
,
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
)
)
input_ids
=
self
.
input_ids
[:
input_batch
.
num_tokens_after_padding
]
# Prefill: Run the eagle speculator with eager mode.
# Prefill: Run the eagle speculator with eager mode.
with
set_forward_context
(
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_tokens
,
input_batch
.
attn_metadata
,
input_batch
.
attn_metadata
,
self
.
vllm_config
,
num_tokens_across_dp
=
None
,
# FIXME
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
input_batch
.
positions
,
hidden_states
=
hidden_states
,
)
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
temperature
=
sampling_metadata
.
temperature
[
cu_num_logits
]
seed
=
sampling_metadata
.
seeds
[
cu_num_logits
]
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
pos
=
input_batch
.
positions
[
last_token_indices
]
+
1
# NOTE(woosuk): For draft sampling, we only consider the temperature
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
# affect the output distribution after rejection sampling.
temperature
=
self
.
temperature
[:
num_reqs
]
seeds
=
self
.
seeds
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
# Gather the values and copy them to the pre-allocated buffers.
torch
.
gather
(
sampling_metadata
.
temperature
,
0
,
cu_num_logits
,
out
=
temperature
)
torch
.
gather
(
sampling_metadata
.
seeds
,
0
,
cu_num_logits
,
out
=
seeds
)
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
draft_tokens
=
gumbel_sample
(
logits
,
temperature
,
seed
,
pos
,
apply_temperature
=
True
logits
,
temperature
,
seed
s
,
pos
+
1
,
apply_temperature
=
True
)
)
if
self
.
num_speculative_steps
==
1
:
if
self
.
num_speculative_steps
==
1
:
# Early exit.
# Early exit.
return
draft_tokens
.
view
(
-
1
,
1
)
return
draft_tokens
.
view
(
-
1
,
1
)
raise
NotImplementedError
(
"num_speculative_steps > 1 is not supported yet."
)
# Save the draft tokens for the first step.
self
.
draft_tokens
[:
num_reqs
,
0
]
=
draft_tokens
# Prepare the inputs for the decode steps.
prepare_eagle_decode
(
draft_tokens
,
hidden_states
,
last_token_indices
,
input_batch
.
seq_lens
,
num_rejected
,
self
.
input_buffers
,
self
.
hidden_states
,
self
.
max_model_len
,
self
.
max_num_reqs
,
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc_gpu
,
pos
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
# Run CUDA graph.
self
.
cudagraph_manager
.
run
(
cudagraph_size
)
return
self
.
draft_tokens
[:
num_reqs
]
# Run eager mode.
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
# HACK(woosuk)
seq_lens_np
=
np
.
full
(
num_reqs
,
self
.
max_model_len
,
dtype
=
np
.
int32
)
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata
=
build_attn_metadata
(
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc_gpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
seq_lens_np
=
seq_lens_np
,
num_computed_tokens_cpu
=
None
,
# FIXME
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
)
self
.
generate_draft
(
num_reqs
,
attn_metadata
,
num_tokens_across_dp
=
None
)
# FIXME
return
self
.
draft_tokens
[:
num_reqs
]
@
triton
.
jit
@
triton
.
jit
def
_prepare_eagle_inputs_kernel
(
def
_prepare_eagle_inputs_kernel
(
last_token_indices_ptr
,
last_token_indices_ptr
,
eagle_input_ids_ptr
,
eagle_input_ids_ptr
,
eagle_positions_ptr
,
target_input_ids_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
next_prefill_tokens_ptr
,
...
@@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel(
...
@@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel(
tl
.
store
(
last_token_indices_ptr
+
batch_idx
,
last_token_index
)
tl
.
store
(
last_token_indices_ptr
+
batch_idx
,
last_token_index
)
tl
.
store
(
eagle_input_ids_ptr
+
last_token_index
,
next_token
)
tl
.
store
(
eagle_input_ids_ptr
+
last_token_index
,
next_token
)
# Copy positions.
for
i
in
range
(
0
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
target_pos
=
tl
.
load
(
target_positions_ptr
+
query_start
+
block
,
mask
=
mask
)
tl
.
store
(
eagle_positions_ptr
+
query_start
+
block
,
target_pos
,
mask
=
mask
)
def
prepare_eagle_inputs
(
def
prepare_eagle_inputs
(
eagle_input_ids
:
torch
.
Tensor
,
input_buffers
:
InputBuffers
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
# [num_reqs]
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
...
@@ -192,12 +377,14 @@ def prepare_eagle_inputs(
...
@@ -192,12 +377,14 @@ def prepare_eagle_inputs(
last_token_indices
=
torch
.
empty
(
last_token_indices
=
torch
.
empty
(
num_reqs
,
num_reqs
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
eagle_input_ids
.
device
,
device
=
num_sampled
.
device
,
)
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
last_token_indices
,
eagle_input_ids
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
positions
,
input_batch
.
input_ids
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping
,
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
...
@@ -207,3 +394,174 @@ def prepare_eagle_inputs(
...
@@ -207,3 +394,174 @@ def prepare_eagle_inputs(
BLOCK_SIZE
=
1024
,
BLOCK_SIZE
=
1024
,
)
)
return
last_token_indices
return
last_token_indices
@
triton
.
jit
def
_prepare_eagle_docode_kernel
(
draft_tokens_ptr
,
output_hidden_states_ptr
,
output_hidden_states_stride
,
last_token_indices_ptr
,
target_seq_lens_ptr
,
num_rejected_ptr
,
input_ids_ptr
,
positions_ptr
,
input_hidden_states_ptr
,
input_hidden_states_stride
,
query_start_loc_ptr
,
seq_lens_ptr
,
hidden_size
,
max_model_len
,
max_num_reqs
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
num_reqs
=
tl
.
num_programs
(
0
)
-
1
if
req_idx
==
num_reqs
:
# Compute query_start_loc. Pad it with the last query_start_loc
# for CUDA graphs.
for
i
in
range
(
0
,
max_num_reqs
+
1
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
q
=
tl
.
where
(
block
<
num_reqs
,
block
,
num_reqs
)
mask
=
block
<
max_num_reqs
+
1
tl
.
store
(
query_start_loc_ptr
+
block
,
q
,
mask
=
mask
)
# Pad seq_lens for CUDA graphs.
for
i
in
range
(
req_idx
,
max_num_reqs
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
max_num_reqs
tl
.
store
(
seq_lens_ptr
+
block
,
0
,
mask
=
mask
)
return
# draft token -> input id.
draft_token
=
tl
.
load
(
draft_tokens_ptr
+
req_idx
)
tl
.
store
(
input_ids_ptr
+
req_idx
,
draft_token
)
# output hidden states -> input hidden states.
src_idx
=
tl
.
load
(
last_token_indices_ptr
+
req_idx
)
for
i
in
range
(
0
,
hidden_size
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
hidden_size
output_hidden_states
=
tl
.
load
(
output_hidden_states_ptr
+
src_idx
*
output_hidden_states_stride
+
block
,
mask
=
mask
,
)
tl
.
store
(
input_hidden_states_ptr
+
req_idx
*
input_hidden_states_stride
+
block
,
output_hidden_states
,
mask
=
mask
,
)
# Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position
=
tl
.
load
(
positions_ptr
+
req_idx
)
position
=
tl
.
minimum
(
position
+
1
,
max_model_len
-
1
)
tl
.
store
(
positions_ptr
+
req_idx
,
position
)
target_seq_len
=
tl
.
load
(
target_seq_lens_ptr
+
req_idx
)
num_rejected
=
tl
.
load
(
num_rejected_ptr
+
req_idx
)
seq_len
=
target_seq_len
-
num_rejected
seq_len
=
tl
.
minimum
(
seq_len
+
1
,
max_model_len
)
tl
.
store
(
seq_lens_ptr
+
req_idx
,
seq_len
)
def
prepare_eagle_decode
(
draft_tokens
:
torch
.
Tensor
,
output_hidden_states
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
,
target_seq_lens
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
input_buffers
:
InputBuffers
,
input_hidden_states
:
torch
.
Tensor
,
max_model_len
:
int
,
max_num_reqs
:
int
,
):
num_reqs
=
draft_tokens
.
shape
[
0
]
hidden_size
=
output_hidden_states
.
shape
[
-
1
]
_prepare_eagle_docode_kernel
[(
num_reqs
+
1
,)](
draft_tokens
,
output_hidden_states
,
output_hidden_states
.
stride
(
0
),
last_token_indices
,
target_seq_lens
,
num_rejected
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
input_buffers
.
query_start_loc
.
gpu
,
input_buffers
.
seq_lens
,
hidden_size
,
max_model_len
,
max_num_reqs
,
BLOCK_SIZE
=
1024
,
)
@
triton
.
jit
def
_update_eagle_inputs_kernel
(
input_ids_ptr
,
positions_ptr
,
input_hidden_states_ptr
,
input_hidden_states_stride
,
seq_lens_ptr
,
max_model_len
,
draft_tokens_ptr
,
output_hidden_states_ptr
,
output_hidden_states_stride
,
hidden_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
# Draft token -> Input ID.
draft_token
=
tl
.
load
(
draft_tokens_ptr
+
req_idx
)
tl
.
store
(
input_ids_ptr
+
req_idx
,
draft_token
)
# Output hidden states -> Input hidden states.
for
i
in
range
(
0
,
hidden_size
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
hidden_size
output_hidden_states
=
tl
.
load
(
output_hidden_states_ptr
+
req_idx
*
output_hidden_states_stride
+
block
,
mask
=
mask
,
)
tl
.
store
(
input_hidden_states_ptr
+
req_idx
*
input_hidden_states_stride
+
block
,
output_hidden_states
,
mask
=
mask
,
)
# Increment position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length.
position
=
tl
.
load
(
positions_ptr
+
req_idx
)
position
=
tl
.
minimum
(
position
+
1
,
max_model_len
-
1
)
tl
.
store
(
positions_ptr
+
req_idx
,
position
)
seq_len
=
tl
.
load
(
seq_lens_ptr
+
req_idx
)
seq_len
=
tl
.
minimum
(
seq_len
+
1
,
max_model_len
)
tl
.
store
(
seq_lens_ptr
+
req_idx
,
seq_len
)
def
update_eagle_inputs
(
draft_tokens
:
torch
.
Tensor
,
output_hidden_states
:
torch
.
Tensor
,
input_buffers
:
InputBuffers
,
hidden_states
:
torch
.
Tensor
,
max_model_len
:
int
,
):
num_reqs
,
hidden_size
=
output_hidden_states
.
shape
_update_eagle_inputs_kernel
[(
num_reqs
,)](
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
positions
,
hidden_states
,
hidden_states
.
stride
(
0
),
input_buffers
.
seq_lens
,
max_model_len
,
draft_tokens
,
output_hidden_states
,
output_hidden_states
.
stride
(
0
),
hidden_size
,
BLOCK_SIZE
=
1024
,
)
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
0 → 100644
View file @
da3222f3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
capture_graphs
,
get_cudagraph_sizes
,
prepare_inputs_to_capture
,
)
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
class
EagleCudaGraphManager
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device
=
device
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
if
self
.
compilation_config
.
cudagraph_mode
is
None
:
self
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
else
:
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self
.
cudagraph_mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
self
.
cudagraph_sizes
=
get_cudagraph_sizes
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
max_num_reqs
,
self
.
max_num_tokens
,
self
.
cudagraph_mode
,
)
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
pool
=
torch
.
cuda
.
graph_pool_handle
()
def
get_cudagraph_size
(
self
,
num_tokens
:
int
)
->
int
|
None
:
return
self
.
cudagraph_sizes
.
get
(
num_tokens
)
def
capture_graph
(
self
,
num_tokens
:
int
,
generate_fn
:
Callable
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
attn_metadata
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
input_buffers
,
block_tables
,
attn_metadata_builders
,
self
.
max_model_len
,
kv_cache_config
,
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
# Warm up.
generate_fn
(
num_tokens
,
attn_metadata
,
num_tokens_across_dp
)
# Capture the graph.
assert
num_tokens
not
in
self
.
graphs
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
self
.
pool
):
generate_fn
(
num_tokens
,
attn_metadata
,
num_tokens_across_dp
)
self
.
graphs
[
num_tokens
]
=
graph
@
torch
.
inference_mode
()
def
capture
(
self
,
generate_fn
:
Callable
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
capture_graphs
(
self
.
cudagraph_sizes
,
self
.
device
,
self
.
capture_graph
,
generate_fn
=
generate_fn
,
input_buffers
=
input_buffers
,
block_tables
=
block_tables
,
attn_metadata_builders
=
attn_metadata_builders
,
kv_cache_config
=
kv_cache_config
,
)
def
run
(
self
,
num_tokens
:
int
)
->
None
:
assert
num_tokens
in
self
.
graphs
self
.
graphs
[
num_tokens
].
replay
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment