Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
963dc0b8
Unverified
Commit
963dc0b8
authored
Jan 17, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 17, 2026
Browse files
[Model Runner V2] Minor optimization for eagle input processing (#32535)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
8cc26acd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
14 deletions
+12
-14
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+2
-8
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+10
-6
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
963dc0b8
...
@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
self
.
speculator
is
not
None
assert
self
.
speculator
is
not
None
last_sampled_tokens
=
self
.
req_states
.
last_sampled_tokens
[
input_batch
.
idx_mapping
]
next_prefill_tokens
=
self
.
req_states
.
next_prefill_tokens
[
input_batch
.
idx_mapping
]
draft_tokens
=
self
.
speculator
.
propose
(
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
input_batch
,
last_hidden_states
,
last_hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_sampled
,
num_rejected
,
num_rejected
,
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
next_prefill_tokens
,
self
.
req_states
.
next_prefill_tokens
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
)
)
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
963dc0b8
...
@@ -195,9 +195,9 @@ class EagleSpeculator:
...
@@ -195,9 +195,9 @@ class EagleSpeculator:
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
# [num_reqs]
# [
max_
num_reqs]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [
max_
num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
# [max_num_reqs]
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
...
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
...
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr
,
eagle_positions_ptr
,
target_input_ids_ptr
,
target_input_ids_ptr
,
target_positions_ptr
,
target_positions_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
next_prefill_tokens_ptr
,
num_sampled_ptr
,
num_sampled_ptr
,
...
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
...
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_len
=
query_end
-
query_start
query_len
=
query_end
-
query_start
...
@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
...
@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
if
num_sampled
>
0
:
if
num_sampled
>
0
:
next_token
=
tl
.
load
(
last_sampled_ptr
+
batch
_idx
).
to
(
tl
.
int32
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state
_idx
).
to
(
tl
.
int32
)
else
:
else
:
# Chunked prefilling.
# Chunked prefilling.
# Get the next prefill token.
# Get the next prefill token.
next_token
=
tl
.
load
(
next_prefill_tokens_ptr
+
batch
_idx
)
next_token
=
tl
.
load
(
next_prefill_tokens_ptr
+
req_state
_idx
)
# Shift target_input_ids by one.
# Shift target_input_ids by one.
for
i
in
range
(
1
,
query_len
,
BLOCK_SIZE
):
for
i
in
range
(
1
,
query_len
,
BLOCK_SIZE
):
...
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
...
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
# [num_reqs]
# [
max_
num_reqs]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [
max_
num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
next_prefill_tokens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
...
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
...
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
input_buffers
.
positions
,
input_buffers
.
positions
,
input_batch
.
input_ids
,
input_batch
.
input_ids
,
input_batch
.
positions
,
input_batch
.
positions
,
input_batch
.
idx_mapping
,
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
num_sampled
,
num_sampled
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment