Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca1b1e72
Unverified
Commit
ca1b1e72
authored
Nov 28, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 28, 2025
Browse files
[Model Runner V2] Refactor prefill token preparation (#29712)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
762a4a6c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
78 deletions
+83
-78
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-1
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+52
-34
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+17
-29
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+8
-11
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+5
-3
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
ca1b1e72
...
...
@@ -78,7 +78,7 @@ class CudaGraphManager:
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_ids
=
input_buffers
.
input_ids
.
gpu
[:
num_tokens
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
attn_metadata
=
prepare_inputs_to_capture
(
num_reqs
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
ca1b1e72
...
...
@@ -3,7 +3,6 @@
from
dataclasses
import
dataclass
from
typing
import
Any
import
numba
import
numpy
as
np
import
torch
...
...
@@ -30,15 +29,12 @@ class InputBuffers:
self
.
pin_memory
=
pin_memory
self
.
idx_mapping
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
self
.
_make_buffer
(
max_num_tokens
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Spec decoding.
self
.
next_prefill_tokens
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
grammar_bitmask
=
self
.
_make_buffer
(
...
...
@@ -120,7 +116,7 @@ class InputBatch:
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
input_ids
=
input_buffers
.
input_ids
.
copy_to_gpu
(
num_tokens
)
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
...
...
@@ -146,41 +142,63 @@ class InputBatch:
)
@
numba
.
njit
(
cache
=
True
)
def
_prepare_prefill_inputs
(
idx_mapping
:
np
.
ndarray
,
# [B]
query_lens
:
np
.
ndarray
,
# [B]
query_start_loc
:
np
.
ndarray
,
# [B + 1]
prefill_token_ids
:
np
.
ndarray
,
# [N, max_model_len]
num_computed_prefill_tokens
:
np
.
ndarray
,
# [N]
input_ids
:
np
.
ndarray
,
# [num_input_tokens]
)
->
None
:
num_reqs
=
idx_mapping
.
shape
[
0
]
query_starts
=
query_start_loc
[:
num_reqs
]
query_ends
=
query_start_loc
[
1
:
num_reqs
+
1
]
starts
=
num_computed_prefill_tokens
[
idx_mapping
]
ends
=
starts
+
query_lens
for
i
in
range
(
num_reqs
):
input_ids
[
query_starts
[
i
]
:
query_ends
[
i
]]
=
prefill_token_ids
[
idx_mapping
[
i
],
starts
[
i
]
:
ends
[
i
]
]
@
triton
.
jit
def
_prepare_prefill_inputs_kernel
(
input_ids_ptr
,
next_prefill_tokens_ptr
,
idx_mapping_ptr
,
query_start_loc_ptr
,
prefill_token_ids_ptr
,
prefill_token_ids_stride
,
prefill_lens_ptr
,
num_computed_tokens_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
prefill_len
=
tl
.
load
(
prefill_lens_ptr
+
req_state_idx
)
num_computed
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
if
num_computed
>=
prefill_len
:
# Not prefill.
return
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_len
=
query_end
-
query_start
prefill_ptr
=
prefill_token_ids_ptr
+
req_state_idx
*
prefill_token_ids_stride
for
i
in
range
(
0
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
tokens
=
tl
.
load
(
prefill_ptr
+
num_computed
+
block
,
mask
=
mask
)
tl
.
store
(
input_ids_ptr
+
query_start
+
block
,
tokens
,
mask
=
mask
)
next_pos
=
num_computed
+
query_len
if
next_pos
<
prefill_len
:
next_token
=
tl
.
load
(
prefill_ptr
+
next_pos
)
tl
.
store
(
next_prefill_tokens_ptr
+
req_state_idx
,
next_token
)
def
prepare_prefill_inputs
(
idx_mapping
:
np
.
ndarray
,
num_scheduled_tokens
:
np
.
ndarray
,
query_start_loc
:
np
.
ndarray
,
prefill_token_ids
:
np
.
ndarray
,
num_computed_prefill_tokens
:
np
.
ndarray
,
input_ids
:
np
.
ndarray
,
input_ids
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
prefill_token_ids
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
num_computed_tokens
:
torch
.
Tensor
,
)
->
None
:
_prepare_prefill_inputs
(
num_reqs
=
idx_mapping
.
shape
[
0
]
_prepare_prefill_inputs_kernel
[(
num_reqs
,)](
input_ids
,
next_prefill_tokens
,
idx_mapping
,
num_scheduled_tokens
,
query_start_loc
,
prefill_token_ids
,
num_computed_prefill_tokens
,
input_ids
,
prefill_token_ids
.
stride
(
0
),
prefill_len
,
num_computed_tokens
,
BLOCK_SIZE
=
1024
,
)
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
ca1b1e72
...
...
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
use_async_scheduling
:
self
.
input_prep_event
=
torch
.
cuda
.
Event
()
self
.
structured_outputs_event
=
torch
.
cuda
.
Event
()
self
.
spec_decode_event
=
torch
.
cuda
.
Event
()
else
:
self
.
input_prep_event
=
None
self
.
structured_outputs_event
=
None
self
.
spec_decode_event
=
None
if
self
.
speculative_config
is
not
None
:
self
.
do_spec_decode
=
True
...
...
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks
[
i
].
append
(
x
+
len
(
block_ids
))
new_block_ids
[
i
].
extend
(
block_ids
)
overwrite
.
append
(
True
)
# Update the GPU tensors for request states.
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
prefill_len
.
copy_to_gpu
()
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
...
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_cpu
=
self
.
input_buffers
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
query_start_loc_np
=
self
.
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
#
Copy
prefill tokens
from CPU to GPU
.
#
Get
prefill tokens.
prepare_prefill_inputs
(
idx_mapping_np
,
num_scheduled_tokens
,
query_start_loc_np
,
self
.
req_states
.
prefill_token_ids
.
np
,
self
.
req_states
.
num_computed_prefill_tokens
,
self
.
input_buffers
.
input_ids
.
np
,
self
.
input_buffers
.
input_ids
,
self
.
req_states
.
next_prefill_tokens
,
idx_mapping
,
query_start_loc_gpu
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
,
)
self
.
input_buffers
.
input_ids
.
copy_to_gpu
(
num_tokens
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens
(
...
...
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices
=
combine_sampled_and_draft_tokens
(
self
.
input_buffers
.
input_ids
.
gpu
,
self
.
input_buffers
.
input_ids
,
idx_mapping
,
self
.
req_states
.
last_sampled_tokens
,
query_start_loc_gpu
,
...
...
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config
=
self
.
kv_cache_config
,
)
input_ids
=
self
.
input_buffers
.
input_ids
.
gpu
[:
num_tokens_after_padding
]
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
return
InputBatch
(
req_ids
=
req_ids
,
...
...
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
idx_mapping_np
=
input_batch
.
idx_mapping_np
with
async_barrier
(
self
.
spec_decode_event
):
self
.
input_buffers
.
next_prefill_tokens
.
np
[:
num_reqs
]
=
(
self
.
req_states
.
prefill_token_ids
.
np
[
idx_mapping_np
,
self
.
req_states
.
num_computed_prefill_tokens
[
idx_mapping_np
],
]
)
next_prefill_tokens
=
self
.
input_buffers
.
next_prefill_tokens
.
copy_to_gpu
(
num_reqs
)
assert
self
.
speculator
is
not
None
last_sampled_tokens
=
self
.
req_states
.
last_sampled_tokens
[
input_batch
.
idx_mapping
]
next_prefill_tokens
=
self
.
req_states
.
next_prefill_tokens
[
input_batch
.
idx_mapping
]
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
sampling_metadata
,
...
...
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states
,
num_sampled
,
num_rejected
,
self
.
req_states
.
last_sampled_tokens
,
last_sampled_tokens
,
next_prefill_tokens
,
)
return
draft_tokens
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
ca1b1e72
...
...
@@ -121,7 +121,7 @@ class EagleSpeculator:
num_tokens_across_dp
=
num_tokens_across_dp
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
self
.
input_buffers
.
input_ids
.
gpu
[:
num_tokens
],
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_tokens
],
)
...
...
@@ -194,7 +194,7 @@ class EagleSpeculator:
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
# [
max_
num_reqs
, 1
]
# [num_reqs]
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
...
...
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
num_sampled_ptr
,
...
...
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
if
num_sampled
>
0
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
batch_idx
).
to
(
tl
.
int32
)
else
:
# Chunked prefilling.
# Get the next prefill token.
...
...
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
# [
max_
num_reqs
, 1
]
# [num_reqs]
last_sampled
:
torch
.
Tensor
,
# [
max_
num_reqs]
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
...
...
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
last_sampled
,
next_prefill_tokens
,
num_sampled
,
...
...
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
last_token_indices
,
target_seq_lens
,
num_rejected
,
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
...
...
@@ -553,7 +550,7 @@ def update_eagle_inputs(
):
num_reqs
,
hidden_size
=
output_hidden_states
.
shape
_update_eagle_inputs_kernel
[(
num_reqs
,)](
input_buffers
.
input_ids
.
gpu
,
input_buffers
.
input_ids
,
input_buffers
.
positions
,
hidden_states
,
hidden_states
.
stride
(
0
),
...
...
vllm/v1/worker/gpu/states.py
View file @
ca1b1e72
...
...
@@ -117,8 +117,7 @@ class RequestState:
self
.
prefill_token_ids
=
UvaBuffer
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
)
self
.
prefill_len
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
prefill_len
=
UvaBuffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# Number of computed tokens.
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens
=
torch
.
zeros
(
...
...
@@ -140,6 +139,9 @@ class RequestState:
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
next_prefill_tokens
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# LoRA.
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
...
...
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
expanded_top_p_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
seeds_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
seeds_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment