Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04bf5a35
Unverified
Commit
04bf5a35
authored
Mar 16, 2026
by
Fynn Schmitt-Ulms
Committed by
GitHub
Mar 16, 2026
Browse files
[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)
parent
43a73f85
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
70 deletions
+35
-70
tests/v1/spec_decode/test_extract_hidden_states.py
tests/v1/spec_decode/test_extract_hidden_states.py
+21
-33
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
...ansfer/kv_connector/v1/example_hidden_states_connector.py
+3
-1
vllm/v1/spec_decode/extract_hidden_states.py
vllm/v1/spec_decode/extract_hidden_states.py
+10
-24
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-12
No files found.
tests/v1/spec_decode/test_extract_hidden_states.py
View file @
04bf5a35
...
@@ -252,29 +252,22 @@ def test_propose():
...
@@ -252,29 +252,22 @@ def test_propose():
]
]
# Sampled token IDs from target model
# Sampled token IDs from target model
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
sampled_token_ids
=
torch
.
tensor
(
[
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
# Mock scheduler output
).
unsqueeze
(
-
1
)
mock_scheduler_output
=
mock
.
MagicMock
()
# Call propose
# Call propose
with
mock
.
patch
(
draft_tokens
=
proposer
.
propose
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
kv_connector_output
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
slot_mappings
=
None
,
)
)
# Verify draft tokens match sampled tokens
# Verify draft tokens match sampled tokens
# Shape should be [batch_size, 1] for num_speculative_tokens=1
# Shape should be [batch_size, 1] for num_speculative_tokens=1
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
]
,
sampled_token_ids
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
# Verify the model was called
# Verify the model was called
model_mock
.
assert_called_once
()
model_mock
.
assert_called_once
()
...
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
...
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
for
_
in
range
(
num_hidden_layers
)
for
_
in
range
(
num_hidden_layers
)
]
]
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
sampled_token_ids
=
torch
.
tensor
(
mock_scheduler_output
=
mock
.
MagicMock
()
[
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
).
unsqueeze
(
-
1
)
with
mock
.
patch
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
_
=
proposer
.
propose
(
draft_tokens
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
slot_mappings
=
None
,
)
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
]
,
sampled_token_ids
)
assert
torch
.
equal
(
draft_tokens
,
sampled_token_ids
)
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
View file @
04bf5a35
...
@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
...
@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
cached_req
=
self
.
_active_requests
[
req_id
]
cached_req
=
self
.
_active_requests
[
req_id
]
req_block_ids
=
self
.
_req_blocks
[
req_id
]
req_block_ids
=
self
.
_req_blocks
[
req_id
]
assert
new_block_ids
is
not
None
if
new_block_ids
is
None
:
continue
block_ids
=
new_block_ids
[
0
]
block_ids
=
new_block_ids
[
0
]
req_block_ids
.
extend
(
block_ids
)
req_block_ids
.
extend
(
block_ids
)
...
...
vllm/v1/spec_decode/extract_hidden_states.py
View file @
04bf5a35
...
@@ -3,26 +3,21 @@
...
@@ -3,26 +3,21 @@
from
__future__
import
annotations
from
__future__
import
annotations
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer
import
has_kv_transfer_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
,
CommonAttentionMetadata
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
,
CommonAttentionMetadata
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
PADDING_SLOT_ID
=
-
1
PADDING_SLOT_ID
=
-
1
...
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
...
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
sampled_token_ids
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
target_hidden_states
:
list
[
torch
.
Tensor
],
target_hidden_states
:
list
[
torch
.
Tensor
],
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
scheduler_output
:
SchedulerOutput
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
KVConnectorOutput
|
None
]
:
)
->
torch
.
Tensor
:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
The ExtractHiddenStatesModel caches the hidden states in the KV cache
...
@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer:
...
@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer:
target_hidden_states: List of hidden state tensors from target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
interface compatibility)
...
@@ -136,8 +129,7 @@ class ExtractHiddenStatesProposer:
...
@@ -136,8 +129,7 @@ class ExtractHiddenStatesProposer:
if
num_tokens_across_dp
is
not
None
:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
with
(
with
set_forward_context
(
set_forward_context
(
per_layer_attn_metadata
,
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
...
@@ -146,12 +138,6 @@ class ExtractHiddenStatesProposer:
...
@@ -146,12 +138,6 @@ class ExtractHiddenStatesProposer:
slot_mapping
=
self
.
_get_slot_mapping
(
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
),
),
(
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
)
if
has_kv_transfer_group
()
else
nullcontext
()
)
as
kv_connector_output
,
):
):
self
.
model
(
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
...
@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
...
@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
# Return the sampled tokens as "draft" tokens
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return
sampled_token_ids
.
unsqueeze
(
-
1
),
kv_connector_output
return
sampled_token_ids
def
_get_slot_mapping
(
def
_get_slot_mapping
(
self
,
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
04bf5a35
...
@@ -4328,23 +4328,12 @@ class GPUModelRunner(
...
@@ -4328,23 +4328,12 @@ class GPUModelRunner(
)
)
target_hidden_states
=
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
]
target_hidden_states
=
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
]
draft_token_ids
,
drafter_kv_connector_output
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
scheduler_output
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
)
)
# Combine KVConnectorOutputs or select the non-empty one
if
self
.
kv_connector_output
and
drafter_kv_connector_output
:
self
.
kv_connector_output
=
KVConnectorOutput
.
merge
(
self
.
kv_connector_output
,
drafter_kv_connector_output
)
else
:
self
.
kv_connector_output
=
(
self
.
kv_connector_output
or
drafter_kv_connector_output
)
next_token_ids
,
valid_sampled_tokens_count
=
(
next_token_ids
,
valid_sampled_tokens_count
=
(
self
.
drafter
.
prepare_next_token_ids_padded
(
self
.
drafter
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
common_attn_metadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment