Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e1da249c
Unverified
Commit
e1da249c
authored
Jan 21, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 21, 2026
Browse files
[Model Runner V2] Minor refactor for `compute_slot_mappings` (#32794)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
9b693d02
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
22 deletions
+33
-22
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+22
-18
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+3
-1
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+8
-3
No files found.
vllm/v1/worker/gpu/block_table.py
View file @
e1da249c
...
...
@@ -116,24 +116,26 @@ class BlockTables:
def
compute_slot_mappings
(
self
,
idx_mapping
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_reqs
=
query_start_loc
.
shape
[
0
]
-
1
num_reqs
=
idx_mapping
.
shape
[
0
]
num_tokens
=
positions
.
shape
[
0
]
num_groups
=
self
.
num_kv_cache_groups
_compute_slot_mappings_kernel
[(
num_groups
,
num_reqs
+
1
)](
num_tokens
,
self
.
max_num_batched_tokens
,
idx_mapping
,
query_start_loc
,
positions
,
self
.
input_
block_table_ptrs
,
self
.
block_table_ptrs
,
self
.
block_table_strides
,
self
.
block_sizes_tensor
,
self
.
slot_mappings
,
self
.
slot_mappings
.
stride
(
0
),
PAD_ID
=
PAD_SLOT_ID
,
BLOCK_SIZE
=
1024
,
# type: ignore
TRITON_
BLOCK_SIZE
=
1024
,
# type: ignore
)
return
self
.
slot_mappings
[:,
:
num_tokens
]
...
...
@@ -176,42 +178,44 @@ def _gather_block_tables_kernel(
def
_compute_slot_mappings_kernel
(
num_tokens
,
max_num_tokens
,
cu_num_tokens
,
# [num_reqs + 1]
idx_mapping
,
# [num_reqs]
query_start_loc
,
# [num_reqs + 1]
pos
,
# [num_tokens]
block_table_ptrs
,
# [num_kv_cache_groups]
block_table_strides
,
# [num_kv_cache_groups]
page
_sizes
,
# [num_kv_cache_groups]
block
_sizes
,
# [num_kv_cache_groups]
slot_mappings_ptr
,
# [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride
,
PAD_ID
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
TRITON_
BLOCK_SIZE
:
tl
.
constexpr
,
):
# kv cache group id
group_id
=
tl
.
program_id
(
0
)
req
_idx
=
tl
.
program_id
(
1
)
batch
_idx
=
tl
.
program_id
(
1
)
slot_mapping_ptr
=
slot_mappings_ptr
+
group_id
*
slot_mappings_stride
if
req
_idx
==
tl
.
num_programs
(
1
)
-
1
:
if
batch
_idx
==
tl
.
num_programs
(
1
)
-
1
:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for
i
in
range
(
num_tokens
,
max_num_tokens
,
BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
for
i
in
range
(
num_tokens
,
max_num_tokens
,
TRITON_
BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
TRITON_
BLOCK_SIZE
)
tl
.
store
(
slot_mapping_ptr
+
offset
,
PAD_ID
,
mask
=
offset
<
max_num_tokens
)
return
block_table_ptr
=
_load_ptr
(
block_table_ptrs
+
group_id
,
tl
.
int32
)
block_table_stride
=
tl
.
load
(
block_table_strides
+
group_id
)
page
_size
=
tl
.
load
(
page
_sizes
+
group_id
)
block
_size
=
tl
.
load
(
block
_sizes
+
group_id
)
start_idx
=
tl
.
load
(
cu_num_tokens
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_tokens
+
req_idx
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
req_state_idx
=
tl
.
load
(
idx_mapping
+
batch_idx
)
start_idx
=
tl
.
load
(
query_start_loc
+
batch_idx
)
end_idx
=
tl
.
load
(
query_start_loc
+
batch_idx
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
TRITON_BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
TRITON_BLOCK_SIZE
)
positions
=
tl
.
load
(
pos
+
offset
,
mask
=
offset
<
end_idx
,
other
=
0
)
block_indices
=
positions
//
page
_size
block_indices
=
positions
//
block
_size
block_numbers
=
tl
.
load
(
block_table_ptr
+
req_idx
*
block_table_stride
+
block_indices
block_table_ptr
+
req_
state_
idx
*
block_table_stride
+
block_indices
)
slot_ids
=
block_numbers
*
page
_size
+
positions
%
page
_size
slot_ids
=
block_numbers
*
block
_size
+
positions
%
block
_size
tl
.
store
(
slot_mapping_ptr
+
offset
,
slot_ids
,
mask
=
offset
<
end_idx
)
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
e1da249c
...
...
@@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
]
idx_mapping
,
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
],
)
# Layer name -> attention metadata.
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
e1da249c
...
...
@@ -138,6 +138,7 @@ class EagleSpeculator:
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
...
...
@@ -149,7 +150,7 @@ class EagleSpeculator:
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
self
.
idx_mapping
[:
num_reqs
]
,
idx_mapping
,
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
...
...
@@ -166,7 +167,9 @@ class EagleSpeculator:
self
.
hidden_states
,
self
.
max_model_len
,
)
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
pos
)
def
capture_model
(
self
)
->
None
:
if
self
.
num_speculative_steps
==
1
:
...
...
@@ -279,7 +282,9 @@ class EagleSpeculator:
self
.
max_num_reqs
,
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
pos
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment