Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04bf5a35
Unverified
Commit
04bf5a35
authored
Mar 16, 2026
by
Fynn Schmitt-Ulms
Committed by
GitHub
Mar 16, 2026
Browse files
[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)
parent
43a73f85
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
70 deletions
+35
-70
tests/v1/spec_decode/test_extract_hidden_states.py
tests/v1/spec_decode/test_extract_hidden_states.py
+21
-33
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
...ansfer/kv_connector/v1/example_hidden_states_connector.py
+3
-1
vllm/v1/spec_decode/extract_hidden_states.py
vllm/v1/spec_decode/extract_hidden_states.py
+10
-24
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-12
No files found.
tests/v1/spec_decode/test_extract_hidden_states.py
View file @
04bf5a35
...
...
@@ -252,29 +252,22 @@ def test_propose():
]
# Sampled token IDs from target model
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
# Mock scheduler output
mock_scheduler_output
=
mock
.
MagicMock
()
sampled_token_ids
=
torch
.
tensor
(
[
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
).
unsqueeze
(
-
1
)
# Call propose
with
mock
.
patch
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
kv_connector_output
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
)
draft_tokens
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
slot_mappings
=
None
,
)
# Verify draft tokens match sampled tokens
# Shape should be [batch_size, 1] for num_speculative_tokens=1
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
]
,
sampled_token_ids
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
# Verify the model was called
model_mock
.
assert_called_once
()
...
...
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
for
_
in
range
(
num_hidden_layers
)
]
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
mock_scheduler_output
=
mock
.
MagicMock
()
with
mock
.
patch
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
_
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
)
sampled_token_ids
=
torch
.
tensor
(
[
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
).
unsqueeze
(
-
1
)
draft_tokens
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
slot_mappings
=
None
,
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
]
,
sampled_token_ids
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
View file @
04bf5a35
...
...
@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
cached_req
=
self
.
_active_requests
[
req_id
]
req_block_ids
=
self
.
_req_blocks
[
req_id
]
assert
new_block_ids
is
not
None
if
new_block_ids
is
None
:
continue
block_ids
=
new_block_ids
[
0
]
req_block_ids
.
extend
(
block_ids
)
...
...
vllm/v1/spec_decode/extract_hidden_states.py
View file @
04bf5a35
...
...
@@ -3,26 +3,21 @@
from
__future__
import
annotations
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
import
torch
import
torch.nn
as
nn
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer
import
has_kv_transfer_group
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
,
CommonAttentionMetadata
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
PADDING_SLOT_ID
=
-
1
...
...
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
sampled_token_ids
:
torch
.
Tensor
,
target_hidden_states
:
list
[
torch
.
Tensor
],
common_attn_metadata
:
CommonAttentionMetadata
,
scheduler_output
:
SchedulerOutput
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
KVConnectorOutput
|
None
]
:
)
->
torch
.
Tensor
:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
...
...
@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer:
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
...
...
@@ -136,22 +129,15 @@ class ExtractHiddenStatesProposer:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
with
(
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
(
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
)
if
has_kv_transfer_group
()
else
nullcontext
()
)
as
kv_connector_output
,
):
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
...
...
@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return
sampled_token_ids
.
unsqueeze
(
-
1
),
kv_connector_output
return
sampled_token_ids
def
_get_slot_mapping
(
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
04bf5a35
...
...
@@ -4328,23 +4328,12 @@ class GPUModelRunner(
)
target_hidden_states
=
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
]
draft_token_ids
,
drafter_kv_connector_output
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
scheduler_output
,
slot_mappings
=
slot_mappings
,
)
# Combine KVConnectorOutputs or select the non-empty one
if
self
.
kv_connector_output
and
drafter_kv_connector_output
:
self
.
kv_connector_output
=
KVConnectorOutput
.
merge
(
self
.
kv_connector_output
,
drafter_kv_connector_output
)
else
:
self
.
kv_connector_output
=
(
self
.
kv_connector_output
or
drafter_kv_connector_output
)
next_token_ids
,
valid_sampled_tokens_count
=
(
self
.
drafter
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment