Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9433acb8
Unverified
Commit
9433acb8
authored
Mar 02, 2026
by
Fynn Schmitt-Ulms
Committed by
GitHub
Mar 02, 2026
Browse files
[Spec Decode] Add hidden states extraction system (#33736)
Signed-off-by:
Fynn Schmitt-Ulms
<
fschmitt@redhat.com
>
parent
d1a6e96d
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
2102 additions
and
38 deletions
+2102
-38
examples/offline_inference/extract_hidden_states.py
examples/offline_inference/extract_hidden_states.py
+58
-0
tests/models/registry.py
tests/models/registry.py
+5
-1
tests/v1/kv_connector/extract_hidden_states_integration/__init__.py
...v_connector/extract_hidden_states_integration/__init__.py
+0
-0
tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py
...or/extract_hidden_states_integration/predictable_llama.py
+120
-0
tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py
...ctor/extract_hidden_states_integration/test_extraction.py
+155
-0
tests/v1/spec_decode/test_extract_hidden_states.py
tests/v1/spec_decode/test_extract_hidden_states.py
+346
-0
vllm/config/speculative.py
vllm/config/speculative.py
+80
-26
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+4
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+6
-0
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
...ansfer/kv_connector/v1/example_hidden_states_connector.py
+354
-0
vllm/model_executor/models/extract_hidden_states.py
vllm/model_executor/models/extract_hidden_states.py
+394
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/configs/extract_hidden_states.py
vllm/transformers_utils/configs/extract_hidden_states.py
+53
-0
vllm/v1/outputs.py
vllm/v1/outputs.py
+53
-1
vllm/v1/spec_decode/extract_hidden_states.py
vllm/v1/spec_decode/extract_hidden_states.py
+395
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+78
-10
No files found.
examples/offline_inference/extract_hidden_states.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
from
safetensors
import
safe_open
from
vllm
import
LLM
,
SamplingParams
# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
llm
=
LLM
(
model
=
"Qwen/Qwen3-8B"
,
# Your target model
speculative_config
=
{
"method"
:
"extract_hidden_states"
,
"num_speculative_tokens"
:
1
,
"draft_model_config"
:
{
"hf_config"
:
{
"eagle_aux_hidden_state_layer_ids"
:
[
# Target model layer indices
1
,
2
,
3
,
4
,
],
}
},
},
kv_transfer_config
=
{
"kv_connector"
:
"ExampleHiddenStatesConnector"
,
"kv_role"
:
"kv_producer"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
tmpdirname
,
},
},
)
prompts
=
[
"Generate a sentence with hidden states"
,
"Write a python function"
]
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
in
outputs
:
print
(
"
\n
Prompt:"
,
output
.
prompt
)
print
(
"Prompt token ids:"
,
output
.
prompt_token_ids
)
hidden_states_path
=
output
.
kv_transfer_params
.
get
(
"hidden_states_path"
)
assert
hidden_states_path
is
not
None
print
(
"Prompt hidden states path:"
,
hidden_states_path
)
with
safe_open
(
hidden_states_path
,
"pt"
)
as
f
:
token_ids
=
f
.
get_tensor
(
"token_ids"
)
hidden_states
=
f
.
get_tensor
(
"hidden_states"
)
print
(
"Extracted token ids:"
,
token_ids
)
# Matches prompt token ids
print
(
"Extracted hidden states shape:"
,
hidden_states
.
shape
)
# [num_hidden_layers, prompt len, hidden size]
print
(
"Extracted hidden states:"
,
hidden_states
)
tests/models/registry.py
View file @
9433acb8
...
@@ -1156,6 +1156,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -1156,6 +1156,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model
=
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
speculative_model
=
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
min_transformers_version
=
"5.1.0"
,
min_transformers_version
=
"5.1.0"
,
),
),
"ExtractHiddenStatesModel"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
,
speculative_method
=
"extract_hidden_states"
,
),
"Glm4MoeMTPModel"
:
_HfExamplesInfo
(
"Glm4MoeMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5"
,
"zai-org/GLM-4.5"
,
speculative_model
=
"zai-org/GLM-4.5"
,
speculative_model
=
"zai-org/GLM-4.5"
,
...
...
tests/v1/kv_connector/extract_hidden_states_integration/__init__.py
0 → 100644
View file @
9433acb8
tests/v1/kv_connector/extract_hidden_states_integration/predictable_llama.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Predictable dummy model for testing extract_hidden_states.
Subclasses LlamaForCausalLM but overrides the model to produce deterministic
hidden states: layer i outputs values equal to (i).
"""
from
collections.abc
import
Iterable
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.sequence
import
IntermediateTensors
class
PredictableLlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
aux_hidden_state_layers
=
tuple
[
int
,
...]()
# Create minimal embed_tokens for embedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
)
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
)
# Required for pipeline parallelism
from
vllm.model_executor.models.utils
import
(
make_empty_intermediate_tensors_factory
,
)
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
self
.
config
.
hidden_size
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Embed input IDs."""
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
extra_layer_kwargs
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
"""Forward pass that produces predictable outputs.
Returns:
If aux_hidden_state_layers is set: (hidden_states, aux_hidden_states)
Otherwise: hidden_states
"""
# Determine sequence length
if
inputs_embeds
is
not
None
:
seq_len
=
inputs_embeds
.
shape
[
0
]
device
=
inputs_embeds
.
device
elif
input_ids
is
not
None
:
seq_len
=
input_ids
.
shape
[
0
]
if
input_ids
.
ndim
==
1
else
input_ids
.
shape
[
-
1
]
device
=
input_ids
.
device
else
:
raise
ValueError
(
"Either input_ids or inputs_embeds must be provided"
)
# Final hidden states (last layer value)
hidden_states
=
torch
.
full
(
(
seq_len
,
self
.
config
.
hidden_size
),
fill_value
=
float
(
self
.
config
.
num_hidden_layers
),
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
# Check if we need auxiliary hidden states
if
len
(
self
.
aux_hidden_state_layers
)
>
0
:
aux_hidden_states
=
[]
for
layer_idx
in
self
.
aux_hidden_state_layers
:
# Fill with (layer_idx) for predictability
layer_hidden
=
torch
.
full
(
(
seq_len
,
self
.
config
.
hidden_size
),
fill_value
=
float
(
layer_idx
),
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
aux_hidden_states
.
append
(
layer_hidden
)
return
hidden_states
,
aux_hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""Skip weight loading."""
return
set
()
class
PredictableLlamaForCausalLM
(
LlamaForCausalLM
):
"""Predictable Llama model for testing.
Overrides _init_model to use PredictableLlamaModel instead of LlamaModel.
"""
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
layer_type
:
type
[
nn
.
Module
]
|
None
=
None
,
):
"""Initialize with predictable model."""
return
PredictableLlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""Skip weight loading for dummy model."""
return
set
()
tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
os
import
pytest
import
torch
from
safetensors
import
safe_open
from
vllm
import
LLM
,
ModelRegistry
,
SamplingParams
def
get_and_check_output
(
output
,
expected_shape
):
assert
output
.
kv_transfer_params
is
not
None
hidden_states_path
=
output
.
kv_transfer_params
.
get
(
"hidden_states_path"
)
assert
hidden_states_path
is
not
None
assert
os
.
path
.
exists
(
hidden_states_path
)
# Load and verify the saved tensors
with
safe_open
(
hidden_states_path
,
"pt"
)
as
f
:
# Check that token_ids and hidden_states are present
tensor_names
=
f
.
keys
()
assert
"token_ids"
in
tensor_names
assert
"hidden_states"
in
tensor_names
token_ids
=
f
.
get_tensor
(
"token_ids"
)
hidden_states
=
f
.
get_tensor
(
"hidden_states"
)
prompt_token_ids
=
output
.
prompt_token_ids
assert
torch
.
equal
(
token_ids
,
torch
.
tensor
(
prompt_token_ids
))
assert
hidden_states
.
shape
==
expected_shape
# Verify hidden_states are not all zeros (i.e., they were actually computed)
assert
not
torch
.
allclose
(
hidden_states
,
torch
.
zeros_like
(
hidden_states
))
return
token_ids
,
hidden_states
@
pytest
.
fixture
(
scope
=
"module"
)
def
predictable_llama_config_path
(
tmp_path_factory
):
"""Create a minimal LlamaConfig for PredictableLlamaForCausalLM."""
from
transformers
import
LlamaConfig
,
LlamaTokenizerFast
config_dir
=
tmp_path_factory
.
mktemp
(
"predictable_llama"
)
# Create a minimal Llama config with small dimensions
config
=
LlamaConfig
(
vocab_size
=
1000
,
hidden_size
=
256
,
intermediate_size
=
512
,
num_hidden_layers
=
24
,
# Enough layers to test various layer_ids
num_attention_heads
=
4
,
num_key_value_heads
=
4
,
max_position_embeddings
=
128
,
architectures
=
[
"PredictableLlamaForCausalLM"
],
)
# Save config
config
.
save_pretrained
(
config_dir
)
# Create a simple tokenizer
tokenizer
=
LlamaTokenizerFast
.
from_pretrained
(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
,
cache_dir
=
os
.
path
.
expanduser
(
"~/.cache/huggingface"
),
)
tokenizer
.
save_pretrained
(
config_dir
)
return
str
(
config_dir
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
def
register_predictable_model
():
"""Register the PredictableLlamaForCausalLM model."""
from
.predictable_llama
import
PredictableLlamaForCausalLM
if
"PredictableLlamaForCausalLM"
not
in
ModelRegistry
.
get_supported_archs
():
ModelRegistry
.
register_model
(
"PredictableLlamaForCausalLM"
,
PredictableLlamaForCausalLM
)
yield
def
test_extract_hidden_states_with_predictable_dummy_model
(
predictable_llama_config_path
,
tmp_path
):
"""Comprehensive test using a predictable dummy model with synthetic weights.
The PredictableLlamaForCausalLM outputs deterministic hidden states where
each layer produces values equal to (layer_index). This test verifies:
1. Hidden states are correctly extracted from requested layers
2. Values match the expected predictable pattern
3. Layer ordering is preserved correctly (non-sequential layer IDs)
4. Multiple prompts of different lengths produce consistent layer values
"""
# Test with non-sequential layer ordering to verify correct association
layer_ids
=
[
5
,
2
,
10
]
num_layers
=
len
(
layer_ids
)
llm
=
LLM
(
model
=
predictable_llama_config_path
,
speculative_config
=
{
"method"
:
"extract_hidden_states"
,
"num_speculative_tokens"
:
1
,
"draft_model_config"
:
{
"hf_config"
:
{
"eagle_aux_hidden_state_layer_ids"
:
layer_ids
}
},
},
kv_transfer_config
=
{
"kv_connector"
:
"ExampleHiddenStatesConnector"
,
"kv_role"
:
"kv_producer"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
tmp_path
},
},
max_model_len
=
128
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
load_format
=
"dummy"
,
# Don't try to load real weights
)
# Test with multiple prompts of different lengths
prompts
=
[
"Short"
,
"Medium length"
,
"Much longer prompt with many tokens"
,
"Much longer prompt with many tokens"
,
# repeated prompt
]
sampling_params
=
SamplingParams
(
max_tokens
=
1
,
temperature
=
0.0
)
hidden_size
=
llm
.
llm_engine
.
model_config
.
get_hidden_size
()
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
del
llm
gc
.
collect
()
assert
len
(
outputs
)
==
len
(
prompts
)
for
output
in
outputs
:
# hidden_states shape is [prompt_len, num_hidden_layers, hidden_size]
expected_shape
=
(
len
(
output
.
prompt_token_ids
),
num_layers
,
hidden_size
,
)
_token_ids
,
hidden_states
=
get_and_check_output
(
output
,
expected_shape
)
for
idx
,
layer_id
in
enumerate
(
layer_ids
):
layer_hidden
=
hidden_states
[:,
idx
,
:]
assert
torch
.
allclose
(
layer_hidden
,
torch
.
full_like
(
layer_hidden
,
layer_id
),
atol
=
1e-5
,
),
(
f
"Layer
{
layer_id
}
at position
{
idx
}
should output
{
float
(
layer_id
)
}
, "
f
"but got mean=
{
layer_hidden
.
mean
():.
3
f
}
, "
f
"min=
{
layer_hidden
.
min
():.
3
f
}
, max=
{
layer_hidden
.
max
():.
3
f
}
"
)
tests/v1/spec_decode/test_extract_hidden_states.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest
import
mock
import
pytest
import
torch
from
tests.v1.attention.utils
import
(
BatchSpec
,
create_common_attn_metadata
,
)
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
,
)
from
vllm.config.load
import
LoadConfig
from
vllm.platforms
import
current_platform
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
model_dir
=
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
def
_create_proposer
(
num_speculative_tokens
:
int
=
1
,
layer_ids
:
list
[
int
]
|
None
=
None
,
)
->
ExtractHiddenStatesProposer
:
"""Create an ExtractHiddenStatesProposer for testing."""
if
layer_ids
is
None
:
layer_ids
=
[
1
,
2
,
3
,
4
]
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
speculative_config
=
SpeculativeConfig
(
target_model_config
=
model_config
,
target_parallel_config
=
ParallelConfig
(),
method
=
"extract_hidden_states"
,
num_speculative_tokens
=
num_speculative_tokens
,
draft_model_config
=
{
"hf_config"
:
{
"eagle_aux_hidden_state_layer_ids"
:
layer_ids
,
}
},
)
device
=
current_platform
.
device_type
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
device
),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
(
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
attention_config
=
AttentionConfig
(),
)
return
ExtractHiddenStatesProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
def
test_proposer_initialization
():
"""Test that the proposer initializes correctly with the right parameters."""
layer_ids
=
[
1
,
2
,
3
,
4
]
proposer
=
_create_proposer
(
num_speculative_tokens
=
1
,
layer_ids
=
layer_ids
)
assert
proposer
.
num_hidden_states
==
len
(
layer_ids
)
assert
proposer
.
vllm_config
.
speculative_config
is
not
None
assert
proposer
.
vllm_config
.
speculative_config
.
num_speculative_tokens
==
1
# Verify the hidden states buffer is correctly shaped
expected_shape
=
(
proposer
.
max_num_tokens
,
len
(
layer_ids
),
proposer
.
hidden_size
,
)
assert
proposer
.
hidden_states
.
shape
==
expected_shape
def
test_proposer_initialization_missing_layer_ids
():
"""Test that initialization fails when layer_ids are not provided."""
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
speculative_config
=
SpeculativeConfig
(
target_model_config
=
model_config
,
target_parallel_config
=
ParallelConfig
(),
method
=
"extract_hidden_states"
,
num_speculative_tokens
=
1
,
draft_model_config
=
{
"hf_config"
:
{}
# Missing eagle_aux_hidden_state_layer_ids
},
)
device
=
current_platform
.
device_type
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
device
),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
(
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
attention_config
=
AttentionConfig
(),
)
with
pytest
.
raises
(
ValueError
,
match
=
"eagle_aux_hidden_state_layer_ids must be set"
):
ExtractHiddenStatesProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
def
test_prepare_next_token_ids_padded
():
"""
Test for prepare_next_token_ids_padded with extract_hidden_states.
Since num_speculative_tokens == 1, sampled_token_ids has shape (batch_size, 1).
For each request we either use the sampled token (if valid and not discarded)
or a backup token from the request state.
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_requests
=
4
batch_spec
=
BatchSpec
(
seq_lens
=
[
5
]
*
num_requests
,
query_lens
=
[
5
]
*
num_requests
,
)
req_ids
=
[
f
"req_
{
i
+
1
}
"
for
i
in
range
(
num_requests
)]
mock_input_batch
=
mock
.
MagicMock
(
spec
=
InputBatch
)
mock_input_batch
.
req_ids
=
req_ids
mock_input_batch
.
num_reqs
=
num_requests
mock_input_batch
.
vocab_size
=
100
mock_requests
=
{}
for
req_id
in
req_ids
:
mock_request
=
mock
.
MagicMock
(
spec
=
CachedRequestState
)
# Each request will have a backup next token id of 10, 20, 30, 40
mock_request
.
get_token_id
.
return_value
=
int
(
req_id
.
split
(
"_"
)[
1
])
*
10
mock_requests
[
req_id
]
=
mock_request
# explicitly discard the last request
discarded_req_mask
=
torch
.
tensor
(
[
False
,
False
,
False
,
True
],
dtype
=
torch
.
bool
,
device
=
device
)
# With num_speculative_tokens=1, sampled_token_ids has shape [batch_size, 1]
sampled_token_ids
=
torch
.
tensor
(
[
[
1
],
# valid, use 1
[
4
],
# valid, use 4
[
-
1
],
# invalid, use backup token "30"
[
2
],
# explicitly discarded, use backup token "40"
],
dtype
=
torch
.
int32
,
device
=
device
,
)
expected_next_token_ids_cpu
=
[
1
,
4
,
30
,
40
]
expected_next_token_ids_tensor
=
torch
.
tensor
(
expected_next_token_ids_cpu
,
dtype
=
torch
.
int32
,
device
=
device
)
proposer
=
_create_proposer
(
num_speculative_tokens
=
1
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
# It doesn't depend on whether the request is discarded
expected_valid_sampled_tokens_count
=
torch
.
tensor
(
[
1
,
1
,
0
,
1
],
dtype
=
torch
.
int32
,
device
=
device
)
next_token_ids
,
valid_sampled_tokens_count
=
proposer
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
sampled_token_ids
,
mock_requests
,
mock_input_batch
,
discarded_req_mask
,
)
assert
torch
.
equal
(
next_token_ids
,
expected_next_token_ids_tensor
)
assert
torch
.
equal
(
valid_sampled_tokens_count
,
expected_valid_sampled_tokens_count
)
def
test_propose
():
"""
Test the propose() method of ExtractHiddenStatesProposer.
This should:
1. Accept target hidden states and sampled token IDs
2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1])
3. Cache the hidden states in the model's KV cache
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
# Setup test parameters
batch_size
=
2
num_tokens
=
5
num_hidden_layers
=
4
proposer
=
_create_proposer
(
num_speculative_tokens
=
1
,
layer_ids
=
list
(
range
(
num_hidden_layers
))
)
hidden_size
=
proposer
.
hidden_size
# Create mock model
model_mock
=
mock
.
MagicMock
()
proposer
.
model
=
model_mock
# Mock attention layer names
proposer
.
attn_layer_names
=
[
"cache_only_layers.28"
]
# Mock attention metadata builder
mock_attn_metadata
=
mock
.
MagicMock
()
mock_attn_metadata_builder
=
mock
.
MagicMock
()
mock_attn_metadata_builder
.
build_for_drafting
.
return_value
=
mock_attn_metadata
proposer
.
attn_metadata_builder
=
mock_attn_metadata_builder
# Create input tensors
batch_spec
=
BatchSpec
(
seq_lens
=
[
3
,
2
],
query_lens
=
[
3
,
2
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
# Create target hidden states: list of tensors, one per layer
# Each tensor has shape [num_tokens, hidden_size]
target_hidden_states
=
[
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
for
_
in
range
(
num_hidden_layers
)
]
# Sampled token IDs from target model
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
# Mock scheduler output
mock_scheduler_output
=
mock
.
MagicMock
()
# Call propose
with
mock
.
patch
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
kv_connector_output
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
)
# Verify draft tokens match sampled tokens
# Shape should be [batch_size, 1] for num_speculative_tokens=1
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
],
sampled_token_ids
)
# Verify the model was called
model_mock
.
assert_called_once
()
# Verify hidden states were copied to the buffer The stacked hidden states
# should have shape [num_tokens, num_hidden_layers, hidden_size]
expected_stacked
=
torch
.
stack
(
target_hidden_states
,
dim
=
1
)
assert
torch
.
allclose
(
proposer
.
hidden_states
[:
num_tokens
],
expected_stacked
,
atol
=
1e-6
)
@
pytest
.
mark
.
parametrize
(
"num_hidden_layers"
,
[
1
,
4
,
8
])
def
test_propose_different_layer_counts
(
num_hidden_layers
):
"""Test that propose works correctly with different numbers of hidden layers."""
device
=
torch
.
device
(
current_platform
.
device_type
)
batch_size
=
2
num_tokens
=
5
proposer
=
_create_proposer
(
num_speculative_tokens
=
1
,
layer_ids
=
list
(
range
(
num_hidden_layers
))
)
hidden_size
=
proposer
.
hidden_size
# Setup mocks
model_mock
=
mock
.
MagicMock
()
proposer
.
model
=
model_mock
proposer
.
attn_layer_names
=
[
"cache_only_layers.28"
]
mock_attn_metadata_builder
=
mock
.
MagicMock
()
mock_attn_metadata_builder
.
build_for_drafting
.
return_value
=
mock
.
MagicMock
()
proposer
.
attn_metadata_builder
=
mock_attn_metadata_builder
batch_spec
=
BatchSpec
(
seq_lens
=
[
3
,
2
],
query_lens
=
[
3
,
2
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
# Create target hidden states
target_hidden_states
=
[
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
for
_
in
range
(
num_hidden_layers
)
]
sampled_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int32
,
device
=
device
)
mock_scheduler_output
=
mock
.
MagicMock
()
with
mock
.
patch
(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
)
as
mock_has_kv
:
mock_has_kv
.
return_value
=
False
draft_tokens
,
_
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
mock_scheduler_output
,
slot_mappings
=
None
,
)
assert
draft_tokens
.
shape
==
(
batch_size
,
1
)
assert
torch
.
equal
(
draft_tokens
[:,
0
],
sampled_token_ids
)
vllm/config/speculative.py
View file @
9433acb8
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
ast
import
copy
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
get_args
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
get_args
from
pydantic
import
Field
,
SkipValidation
,
model_validator
from
pydantic
import
Field
,
SkipValidation
,
model_validator
...
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
...
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp"
,
"pangu_ultra_moe_mtp"
,
"step3p5_mtp"
,
"step3p5_mtp"
,
]
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
MTPModelTypes
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
SpeculativeMethod
=
Literal
[
SpeculativeMethod
=
Literal
[
"ngram"
,
"ngram"
,
"medusa"
,
"medusa"
,
...
@@ -181,9 +182,22 @@ class SpeculativeConfig:
...
@@ -181,9 +182,22 @@ class SpeculativeConfig:
the final hidden states.
the final hidden states.
"""
"""
factors
:
list
[
Any
]
=
[]
factors
:
list
[
Any
]
=
[]
# Eagle3 affects the computation graph because it returns intermediate
# Eagle3 and extract_hidden_states affect the computation graph because
# hidden states in addition to the final hidden state.
# they return intermediate hidden states in addition to the final hidden state.
factors
.
append
(
self
.
method
==
"eagle3"
)
uses_aux_hidden_states
=
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
factors
.
append
(
uses_aux_hidden_states
)
# The specific layers used also affect the computation graph
if
uses_aux_hidden_states
and
self
.
draft_model_config
is
not
None
:
layer_ids
=
getattr
(
self
.
draft_model_config
.
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
None
,
)
if
layer_ids
is
not
None
:
# Convert to tuple to make it hashable
factors
.
append
(
tuple
(
layer_ids
))
hash_str
=
safe_hash
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
hash_str
=
safe_hash
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
return
hash_str
...
@@ -352,6 +366,8 @@ class SpeculativeConfig:
...
@@ -352,6 +366,8 @@ class SpeculativeConfig:
self
.
model
=
"ngram"
self
.
model
=
"ngram"
elif
self
.
method
==
"suffix"
:
elif
self
.
method
==
"suffix"
:
self
.
model
=
"suffix"
self
.
model
=
"suffix"
elif
self
.
method
==
"extract_hidden_states"
:
self
.
model
=
"extract_hidden_states"
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"num_speculative_tokens was provided but without speculative model."
"num_speculative_tokens was provided but without speculative model."
...
@@ -394,6 +410,34 @@ class SpeculativeConfig:
...
@@ -394,6 +410,34 @@ class SpeculativeConfig:
self
.
draft_parallel_config
=
self
.
target_parallel_config
self
.
draft_parallel_config
=
self
.
target_parallel_config
elif
self
.
method
==
"suffix"
:
elif
self
.
method
==
"suffix"
:
self
.
_validate_suffix_decoding
()
self
.
_validate_suffix_decoding
()
elif
self
.
method
==
"extract_hidden_states"
:
from
vllm.transformers_utils.configs.extract_hidden_states
import
(
ExtractHiddenStatesConfig
,
)
# ExtractHiddenStatesModel is instantiated manually in load_model()
# We just need to store the target model config for KV cache shape info
self
.
model
=
"extract_hidden_states"
self
.
prompt_lookup_max
=
0
self
.
prompt_lookup_min
=
0
if
hasattr
(
self
.
draft_model_config
,
"hf_config"
):
hf_config
=
self
.
draft_model_config
.
hf_config
.
to_dict
()
elif
(
isinstance
(
self
.
draft_model_config
,
dict
)
and
"hf_config"
in
self
.
draft_model_config
):
hf_config
=
self
.
draft_model_config
[
"hf_config"
]
else
:
hf_config
=
{}
self
.
draft_model_config
=
copy
.
copy
(
self
.
target_model_config
)
self
.
draft_model_config
.
hf_config
=
ExtractHiddenStatesConfig
(
self
.
draft_model_config
.
hf_config
,
**
hf_config
)
self
.
update_arch_
()
self
.
draft_parallel_config
=
self
.
target_parallel_config
else
:
else
:
self
.
prompt_lookup_max
=
0
self
.
prompt_lookup_max
=
0
self
.
prompt_lookup_min
=
0
self
.
prompt_lookup_min
=
0
...
@@ -478,23 +522,8 @@ class SpeculativeConfig:
...
@@ -478,23 +522,8 @@ class SpeculativeConfig:
method
=
self
.
method
,
method
=
self
.
method
,
model_type
=
"eagle"
,
model_type
=
"eagle"
,
)
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self
.
draft_model_config
.
hf_config
=
eagle_config
self
.
draft_model_config
.
hf_config
=
eagle_config
self
.
draft_model_config
.
hf_text_config
=
get_hf_text_config
(
self
.
update_arch_
()
self
.
draft_model_config
.
hf_config
)
self
.
draft_model_config
.
model_arch_config
=
(
self
.
draft_model_config
.
get_model_arch_config
()
)
model_info
,
arch
=
(
self
.
draft_model_config
.
registry
.
inspect_model_cls
(
self
.
draft_model_config
.
architectures
,
self
.
draft_model_config
,
)
)
self
.
draft_model_config
.
_model_info
=
model_info
self
.
draft_model_config
.
_architecture
=
arch
if
self
.
num_speculative_tokens
is
not
None
and
hasattr
(
if
self
.
num_speculative_tokens
is
not
None
and
hasattr
(
self
.
draft_model_config
.
hf_config
,
"num_lookahead_tokens"
self
.
draft_model_config
.
hf_config
,
"num_lookahead_tokens"
...
@@ -671,6 +700,24 @@ class SpeculativeConfig:
...
@@ -671,6 +700,24 @@ class SpeculativeConfig:
)
)
return
speculative_draft_tensor_parallel_size
return
speculative_draft_tensor_parallel_size
def
update_arch_
(
self
):
"""
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
architectures-related fields in self.draft_model_config
"""
self
.
draft_model_config
.
hf_text_config
=
get_hf_text_config
(
self
.
draft_model_config
.
hf_config
)
self
.
draft_model_config
.
model_arch_config
=
(
self
.
draft_model_config
.
get_model_arch_config
()
)
model_info
,
arch
=
self
.
draft_model_config
.
registry
.
inspect_model_cls
(
self
.
draft_model_config
.
architectures
,
self
.
draft_model_config
,
)
self
.
draft_model_config
.
_model_info
=
model_info
self
.
draft_model_config
.
_architecture
=
arch
@
staticmethod
@
staticmethod
def
create_draft_parallel_config
(
def
create_draft_parallel_config
(
target_parallel_config
:
ParallelConfig
,
target_parallel_config
:
ParallelConfig
,
...
@@ -718,7 +765,7 @@ class SpeculativeConfig:
...
@@ -718,7 +765,7 @@ class SpeculativeConfig:
self
.
draft_parallel_config
self
.
draft_parallel_config
)
)
eagle3_target
_supported
=
[
aux_hidden_states
_supported
=
[
"llama"
,
"llama"
,
"qwen"
,
"qwen"
,
"minicpm"
,
"minicpm"
,
...
@@ -729,16 +776,16 @@ class SpeculativeConfig:
...
@@ -729,16 +776,16 @@ class SpeculativeConfig:
"nemotron_h"
,
"nemotron_h"
,
]
]
if
(
if
(
self
.
method
==
"eagle3"
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
and
self
.
target_model_config
and
self
.
target_model_config
and
not
any
(
and
not
any
(
supported_model
in
self
.
target_model_config
.
hf_text_config
.
model_type
supported_model
in
self
.
target_model_config
.
hf_text_config
.
model_type
for
supported_model
in
eagle3_target
_supported
for
supported_model
in
aux_hidden_states
_supported
)
)
):
):
raise
ValueError
(
raise
ValueError
(
f
"
Eagle3
is only supported for
{
eagle3_target_supported
}
models. "
# noqa: E501
f
"
{
self
.
method
}
is only supported for
{
aux_hidden_states_supported
}
"
f
"Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
f
"
models.
Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
)
)
self
.
verify_equal_vocab_size_if_draft_model
()
self
.
verify_equal_vocab_size_if_draft_model
()
return
self
return
self
...
@@ -782,8 +829,15 @@ class SpeculativeConfig:
...
@@ -782,8 +829,15 @@ class SpeculativeConfig:
def
uses_draft_model
(
self
)
->
bool
:
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
return
self
.
method
==
"draft_model"
def
uses_extract_hidden_states
(
self
)
->
bool
:
return
self
.
method
==
"extract_hidden_states"
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
method
=
self
.
method
model
=
None
if
method
in
(
"ngram"
,
"suffix"
)
else
self
.
draft_model_config
.
model
model
=
(
None
if
method
in
(
"ngram"
,
"suffix"
,
"extract_hidden_states"
)
else
self
.
draft_model_config
.
model
)
num_spec_tokens
=
self
.
num_speculative_tokens
num_spec_tokens
=
self
.
num_speculative_tokens
return
f
"SpeculativeConfig(
{
method
=
}
,
{
model
=
}
,
{
num_spec_tokens
=
}
)"
return
f
"SpeculativeConfig(
{
method
=
}
,
{
model
=
}
,
{
num_spec_tokens
=
}
)"
vllm/distributed/kv_events.py
View file @
9433acb8
...
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
...
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
def
clear_events
(
self
)
->
None
:
def
clear_events
(
self
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
merge
(
self
,
other
:
"KVConnectorKVEvents"
)
->
"KVConnectorKVEvents"
:
self
.
add_events
(
other
.
get_all_events
())
return
self
class
EventPublisher
(
ABC
):
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches with data parallelism
"""Lightweight publisher for EventBatch batches with data parallelism
...
...
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
9433acb8
...
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
...
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
"ExampleConnector"
,
"ExampleConnector"
,
)
)
KVConnectorFactory
.
register_connector
(
"ExampleHiddenStatesConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector"
,
"ExampleHiddenStatesConnector"
,
)
KVConnectorFactory
.
register_connector
(
KVConnectorFactory
.
register_connector
(
"P2pNcclConnector"
,
"P2pNcclConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector"
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
safetensors
import
torch
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.core.sched.output
import
NewRequestData
,
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
def
extract_from_kv_cache
(
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
num_tokens
:
int
,
)
->
torch
.
Tensor
:
"""Extract data from KV cache
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
"""
padded_kv
=
kv_cache
.
flatten
(
0
,
1
)[
slot_mapping
]
# shape: [len(slot_mapping), num_heads, head_size]
return
padded_kv
[:
num_tokens
]
# shape: [num_tokens, num_heads, head_size]
@
dataclass
class
ReqMeta
:
# Request ID
req_id
:
str
# Request filename
filename
:
str
# Request tokens
token_ids
:
torch
.
Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping
:
torch
.
Tensor
# Whether this request is a new request or partially computed already
new_req
:
bool
@
staticmethod
def
make_meta
(
req_id
:
str
,
filename
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_size
:
int
,
new_req
:
bool
,
)
->
"ReqMeta"
:
token_ids_tensor
=
torch
.
tensor
(
token_ids
)
block_ids_tensor
=
torch
.
tensor
(
block_ids
)
num_blocks
=
block_ids_tensor
.
shape
[
0
]
block_offsets
=
torch
.
arange
(
0
,
block_size
)
slot_mapping
=
(
block_offsets
.
reshape
((
1
,
block_size
))
+
block_ids_tensor
.
reshape
((
num_blocks
,
1
))
*
block_size
)
slot_mapping
=
slot_mapping
.
flatten
()
return
ReqMeta
(
req_id
=
req_id
,
filename
=
filename
,
token_ids
=
token_ids_tensor
,
slot_mapping
=
slot_mapping
,
new_req
=
new_req
,
)
@
dataclass
class
ExampleHiddenStatesConnectorMetadata
(
KVConnectorMetadata
):
requests
:
list
[
ReqMeta
]
=
field
(
default_factory
=
list
)
def
add_request
(
self
,
req_id
:
str
,
filename
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_size
:
int
,
new_req
:
bool
=
True
,
)
->
None
:
self
.
requests
.
append
(
ReqMeta
.
make_meta
(
req_id
,
filename
,
token_ids
,
block_ids
,
block_size
,
new_req
)
)
class
ExampleHiddenStatesConnector
(
KVConnectorBase_V1
):
"""
Simple debug implementation of a HiddenStatesConnector.
Simply extracts the hidden states from the kv cache and stores them to disk.
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
"""
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
# Must be False so that drafter kv cache isn't merged with verifier's
return
False
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
,
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_storage_path
=
self
.
_kv_transfer_config
.
get_from_extra_config
(
"shared_storage_path"
,
"/tmp"
)
self
.
cache_layers
:
list
[
str
]
=
[]
# set by self.register_kv_caches
logger
.
info
(
self
.
_kv_transfer_config
)
logger
.
info
(
"Shared storage path is %s"
,
self
.
_storage_path
)
assert
self
.
_vllm_config
.
speculative_config
is
not
None
,
(
"ExampleHiddenStatesConnector only works when using "
"'extract_hidden_states' speculative method"
)
spec_config
=
self
.
_vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
self
.
num_hidden_states
=
len
(
getattr
(
spec_config
,
"eagle_aux_hidden_state_layer_ids"
,
[])
)
self
.
_request_filenames
:
dict
[
str
,
str
]
=
{}
self
.
_active_requests
:
dict
[
str
,
NewRequestData
]
=
{}
self
.
_req_blocks
:
dict
[
str
,
list
[
int
]]
=
{}
# ==============================
# Worker-side methods
# ==============================
def
start_load_kv
(
self
,
*
args
,
**
kwargs
:
Any
)
->
None
:
pass
# Empty implementation of abstract method
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
pass
# Empty implementation of abstract method
def
wait_for_save
(
self
):
pass
# Empty implementation of abstract method
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
from
vllm.model_executor.models.extract_hidden_states
import
(
CacheOnlyAttentionLayer
,
)
# Filter layers to only include CacheOnlyAttentionLayers
layers
=
get_layers_from_vllm_config
(
self
.
_vllm_config
,
CacheOnlyAttentionLayer
,
list
(
kv_caches
.
keys
())
)
self
.
cache_layers
=
list
(
layers
.
keys
())
assert
len
(
self
.
cache_layers
)
==
1
,
(
f
"Expected 1 CacheOnlyAttentionLayer, got
{
len
(
self
.
cache_layers
)
}
"
)
def
save_kv_layer
(
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
if
layer_name
not
in
self
.
cache_layers
:
return
from
vllm.model_executor.models.extract_hidden_states
import
(
CacheOnlyAttentionMetadata
,
)
assert
isinstance
(
attn_metadata
,
CacheOnlyAttentionMetadata
),
(
"ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
)
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
ExampleHiddenStatesConnectorMetadata
)
os
.
makedirs
(
self
.
_storage_path
,
exist_ok
=
True
)
for
request
in
connector_metadata
.
requests
:
hidden_states
=
extract_from_kv_cache
(
kv_layer
,
request
.
slot_mapping
,
request
.
token_ids
.
shape
[
0
]
)
tensors
=
{
"hidden_states"
:
hidden_states
.
detach
().
cpu
(),
"token_ids"
:
request
.
token_ids
.
detach
().
cpu
(),
}
safetensors
.
torch
.
save_file
(
tensors
,
request
.
filename
)
# ==============================
# Scheduler-side methods
# ==============================
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
|
None
,
bool
]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# This connector is store-only, so we don't need to load any tokens
return
0
,
False
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
):
# Usually used to handle allocation of new blocks for requests that are loading
# tokens from connector's external kv cache. We never load from external cache
# so this is a no-op.
assert
num_external_tokens
==
0
,
"This connector is store-only"
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
,
)
->
KVConnectorMetadata
:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta
=
ExampleHiddenStatesConnectorMetadata
()
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
token_ids
=
new_req
.
prompt_token_ids
or
[]
filename
=
os
.
path
.
join
(
self
.
_storage_path
,
f
"
{
new_req
.
req_id
}
.safetensors"
)
meta
.
add_request
(
new_req
.
req_id
,
filename
=
filename
,
token_ids
=
token_ids
,
block_ids
=
new_req
.
block_ids
[
0
],
block_size
=
self
.
_block_size
,
)
self
.
_request_filenames
[
new_req
.
req_id
]
=
filename
self
.
_active_requests
[
new_req
.
req_id
]
=
new_req
self
.
_req_blocks
[
new_req
.
req_id
]
=
list
(
new_req
.
block_ids
[
0
])
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
if
req_id
not
in
self
.
_active_requests
:
continue
new_block_ids
=
cached_reqs
.
new_block_ids
[
i
]
cached_req
=
self
.
_active_requests
[
req_id
]
req_block_ids
=
self
.
_req_blocks
[
req_id
]
assert
new_block_ids
is
not
None
block_ids
=
new_block_ids
[
0
]
req_block_ids
.
extend
(
block_ids
)
filename
=
os
.
path
.
join
(
self
.
_storage_path
,
f
"
{
req_id
}
.safetensors"
)
meta
.
add_request
(
req_id
=
req_id
,
filename
=
filename
,
token_ids
=
cached_req
.
prompt_token_ids
or
[],
block_ids
=
req_block_ids
,
block_size
=
self
.
_block_size
,
new_req
=
False
,
)
return
meta
def
request_finished
(
self
,
request
:
"Request"
,
block_ids
:
list
[
int
],
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id
=
request
.
request_id
req_filename
=
self
.
_request_filenames
.
pop
(
req_id
,
None
)
_
=
self
.
_active_requests
.
pop
(
req_id
,
None
)
_
=
self
.
_req_blocks
.
pop
(
req_id
,
None
)
return
False
,
{
"hidden_states_path"
:
req_filename
}
@
classmethod
def
get_required_kvcache_layout
(
cls
,
vllm_config
:
"VllmConfig"
)
->
str
|
None
:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if
cls
is
KVConnectorBase_V1
:
raise
TypeError
(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
# NHD means we have (num_tokens, num_heads)
# HND means we have (num_heads, num_tokens)
# For now, we only support NHD layout since this keeps the
# hidden states for each token together in memory.
# HND is primarily used when sharding heads across devices.
return
"NHD"
vllm/model_executor/models/extract_hidden_states.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Hidden States Extractor Model.
This model extracts and caches hidden states from the target model
without performing actual token generation. It's used with the
extract_hidden_states speculative decoding method.
"""
from
collections.abc
import
Iterable
from
typing
import
ClassVar
import
torch
import
torch.nn
as
nn
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.attention.attention
import
set_default_quant_scales
from
vllm.model_executor.layers.attention.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.utils.torch_utils
import
kv_cache_dtype_str_to_dtype
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadataBuilder
,
AttentionType
,
CommonAttentionMetadata
,
is_quantized_kv_cache
,
)
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheSpec
,
MLAAttentionSpec
,
)
########## Custom Ops ########
def
unified_kv_cache_update
(
to_cache
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
"""
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
forward_context
=
get_forward_context
()
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
f
"Expected slot_mapping to be a dict, got
{
type
(
slot_mapping
)
}
. "
)
layer_slot_mapping
=
slot_mapping
.
get
(
layer_name
)
if
layer_slot_mapping
is
not
None
:
assert
hasattr
(
attn_layer
.
impl
,
"do_kv_cache_update"
),
(
f
"
{
attn_layer
.
impl
.
__class__
.
__name__
}
does not support kv cache update"
)
attn_layer
.
impl
.
do_kv_cache_update
(
attn_layer
,
to_cache
,
kv_cache
,
layer_slot_mapping
,
)
return
torch
.
empty
(
0
,
device
=
kv_cache
.
device
,
dtype
=
kv_cache
.
dtype
)
@
maybe_transfer_kv_layer
def
dummy_attention
(
layer_name
,
_placeholder
):
# Note: layer_name arg required by @maybe_transfer_kv_layer
return
_placeholder
def
basic_cache
(
to_cache
:
torch
.
Tensor
,
# shape: [num_blocks, block_size, num_heads, head_size]
kv_cache
:
torch
.
Tensor
,
# shape: [seq_len, num_heads, head_size]
slot_mapping
:
torch
.
Tensor
,
# shape: [seq_len]
):
num_blocks
,
block_size
,
num_heads
,
head_size
=
kv_cache
.
shape
token_kv_cache
=
kv_cache
.
view
(
num_blocks
*
block_size
,
num_heads
,
head_size
)
token_kv_cache
[
slot_mapping
]
=
to_cache
######### CacheOnlyAttentionBackend ########
class
CacheOnlyAttentionBackend
(
AttentionBackend
):
"""Attention backend that only caches KV without computing attention."""
accept_output_buffer
:
bool
=
False
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
,
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
]
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
def
get_name
()
->
str
:
return
"CACHE_ONLY_ATTN"
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
return
attn_type
==
AttentionType
.
DECODER
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
True
@
staticmethod
def
get_impl_cls
()
->
type
[
"CacheOnlyAttentionImpl"
]:
return
CacheOnlyAttentionImpl
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
# We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
# We also don't use a k/v (2) dim
return
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_builder_cls
()
->
type
[
"CacheOnlyAttentionMetadataBuilder"
]:
return
CacheOnlyAttentionMetadataBuilder
@
staticmethod
def
use_cascade_attention
(
*
args
,
**
kwargs
)
->
bool
:
return
False
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[]
class
CacheOnlyAttentionMetadata
:
def
__init__
(
self
,
slot_mapping
:
torch
.
Tensor
):
self
.
slot_mapping
=
slot_mapping
class
CacheOnlyAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
CacheOnlyAttentionMetadata
]
):
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
CacheOnlyAttentionMetadata
:
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
raise
NotImplementedError
(
"Cascade attention not supported by CacheOnlyAttention"
)
causal
=
common_attn_metadata
.
causal
if
not
causal
:
raise
NotImplementedError
(
"Non-causal attention not supported by CacheOnlyAttention"
)
return
CacheOnlyAttentionMetadata
(
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
)
class
CacheOnlyAttentionImpl
(
AttentionImpl
):
"""Attention implementation that only caches KV states."""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
kv_cache_dtype
:
str
,
kv_cache_torch_dtype
:
torch
.
dtype
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_torch_dtype
=
kv_cache_torch_dtype
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
f
"Unsupported attention type:
{
attn_type
}
"
)
if
is_quantized_kv_cache
(
kv_cache_dtype
):
raise
NotImplementedError
(
"Quantized KV cache not supported"
)
self
.
num_queries_per_kv
=
1
def
do_kv_cache_update
(
self
,
layer
,
to_cache
,
kv_cache
,
slot_mapping
,
):
assert
to_cache
.
dtype
==
self
.
kv_cache_torch_dtype
,
(
f
"Data to cache must be
{
self
.
kv_cache_torch_dtype
}
, got
{
to_cache
.
dtype
}
"
)
assert
kv_cache
.
dtype
==
self
.
kv_cache_torch_dtype
,
(
f
"KV cache must be
{
self
.
kv_cache_torch_dtype
}
, got
{
kv_cache
.
dtype
}
"
)
basic_cache
(
to_cache
,
kv_cache
,
slot_mapping
)
def
forward
(
self
,
*
args
,
**
kwargs
):
# Empty implementation of abstract method
pass
############## CacheOnlyAttentionLayer (replaces Attention) ############
class
CacheOnlyAttentionLayer
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer that only caches key/value states without computing attention."""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
cache_config
:
CacheConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
):
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
layer_name
=
prefix
vllm_config
=
get_current_vllm_config
()
# KV cache configuration
cache_config
=
cache_config
or
vllm_config
.
cache_config
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
self
.
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
self
.
block_size
=
16
assert
kv_cache_dtype
in
[
"auto"
,
"bfloat16"
,
"float16"
],
(
"CacheOnlyAttentionLayer doesn't currently support quantized kv cache but"
f
"kv cache dtype was set to
{
kv_cache_dtype
}
"
)
self
.
kv_cache_torch_dtype
=
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
,
vllm_config
.
model_config
)
# Initialize KV cache quantization attributes
set_default_quant_scales
(
self
,
register_buffer
=
True
)
# Attention backend
self
.
attn_backend
=
CacheOnlyAttentionBackend
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
kv_cache_dtype
,
self
.
kv_cache_torch_dtype
,
attn_type
,
)
assert
not
self
.
attn_backend
.
forward_includes_kv_cache_update
,
(
"KV cache update should be independent of forward"
)
# Placeholder KV cache (replaced by bind_kv_cache)
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
vllm_config
.
parallel_config
.
pipeline_parallel_size
)
]
# Register in compilation context
compilation_config
=
vllm_config
.
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
forward
(
self
,
to_cache
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Cache hidden states as KV pairs without computing attention.
Args:
to_cache: The tensor to insert into the kv cache.
shape [num_tokens, num_heads, head_size]
Returns:
Dummy output tensor (not used)
"""
# Note: we set num_heads to num_hidden_layers and
# head_size to hidden_size for hidden states storage
output
=
torch
.
empty
(
0
,
device
=
to_cache
.
device
,
dtype
=
to_cache
.
dtype
)
# Note: dummy_out is used to force torch.compile to preserve ordering between
# cache update and attention op (which triggers kv_connector transfer)
dummy_out
=
unified_kv_cache_update
(
to_cache
,
self
.
layer_name
)
# Triggers kv_connector transfer via decorator
_
=
dummy_attention
(
self
.
layer_name
,
dummy_out
)
return
output
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
# Note: we use MLAAttentionSpec here to because it will
# produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
# whereas FullAttentionSpec will add an additional factor of 2
return
MLAAttentionSpec
(
block_size
=
self
.
block_size
,
num_kv_heads
=
self
.
num_heads
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_torch_dtype
,
)
############ ExtractHiddenStatesModel definition ##########
class
ExtractHiddenStatesModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
vllm_config
=
vllm_config
self
.
hf_config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
self
.
target_num_hidden_layers
=
(
vllm_config
.
model_config
.
get_total_num_hidden_layers
()
)
self
.
num_hidden_states
=
len
(
getattr
(
self
.
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
[])
)
cache_config
=
vllm_config
.
cache_config
# Create a single cache-only attention layer
# Note: We set num_heads <- self.num_hidden_states
# and head_size <- hidden_size so that we can insert
# the hidden states directly into the cache without
# reshaping
self
.
cache_only_layers
=
nn
.
ModuleDict
(
{
str
(
self
.
target_num_hidden_layers
):
CacheOnlyAttentionLayer
(
num_heads
=
self
.
num_hidden_states
,
head_size
=
self
.
hidden_size
,
cache_config
=
cache_config
,
prefix
=
maybe_prefix
(
prefix
,
f
"cache_only_layers.
{
self
.
target_num_hidden_layers
}
"
),
)
}
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
None
:
"""Process and cache hidden states.
Args:
hidden_states: Hidden states from target model
shape: [num_tokens, num_hidden_states, hidden_size]
Returns:
Tuple of (dummy_output, dummy_output) - both unused
"""
# Call dummy attention layer to cache hidden states
# Output is ignored - we only care about the KV cache side effects
_
=
self
.
cache_only_layers
[
str
(
self
.
target_num_hidden_layers
)](
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""No weights to load for this dummy model."""
return
set
()
vllm/model_executor/models/registry.py
View file @
9433acb8
...
@@ -512,6 +512,7 @@ _MULTIMODAL_MODELS = {
...
@@ -512,6 +512,7 @@ _MULTIMODAL_MODELS = {
}
}
_SPECULATIVE_DECODING_MODELS
=
{
_SPECULATIVE_DECODING_MODELS
=
{
"ExtractHiddenStatesModel"
:
(
"extract_hidden_states"
,
"ExtractHiddenStatesModel"
),
"MiMoMTPModel"
:
(
"mimo_mtp"
,
"MiMoMTP"
),
"MiMoMTPModel"
:
(
"mimo_mtp"
,
"MiMoMTP"
),
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
...
...
vllm/transformers_utils/configs/extract_hidden_states.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Config definitions for ExtractHiddenStatesModel, to be used with
the extract_hidden_states spec decoding method."""
import
os
from
transformers
import
PretrainedConfig
class
ExtractHiddenStatesConfig
(
PretrainedConfig
):
model_type
=
"extract_hidden_states"
def
__init__
(
self
,
model
:
PretrainedConfig
|
dict
|
None
=
None
,
method
:
str
|
None
=
"extract_hidden_states"
,
**
kwargs
,
):
assert
method
==
"extract_hidden_states"
if
isinstance
(
model
,
dict
):
model_dict
=
model
elif
isinstance
(
model
,
PretrainedConfig
):
model_dict
=
model
.
to_dict
()
else
:
model_dict
=
{}
# Combine: model_dict first, then kwargs override
combined
=
{
**
model_dict
,
**
kwargs
}
# Remove architectures from the base, we'll set it explicitly
combined
=
{
k
:
v
for
k
,
v
in
combined
.
items
()
if
k
!=
"architectures"
}
combined
[
"architectures"
]
=
[
"ExtractHiddenStatesModel"
]
super
().
__init__
(
**
combined
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
str
|
os
.
PathLike
,
**
kwargs
,
)
->
"ExtractHiddenStatesConfig"
:
config_dict
,
kwargs
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
,
**
kwargs
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
def
to_json_string
(
self
,
use_diff
:
bool
=
True
)
->
str
:
# we override use_diff to False as initializing
# ExtractHiddenStatesConfig with default arguments is not supported
del
use_diff
return
super
().
to_json_string
(
use_diff
=
False
)
vllm/v1/outputs.py
View file @
9433acb8
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
TypeAlias
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
TypeAlias
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -120,6 +121,20 @@ class SamplerOutput:
...
@@ -120,6 +121,20 @@ class SamplerOutput:
logprobs_tensors
:
LogprobsTensors
|
None
logprobs_tensors
:
LogprobsTensors
|
None
T
=
TypeVar
(
"T"
)
def
_combine_non_none
(
f
:
Callable
[[
T
,
T
],
T
],
items
:
list
[
T
|
None
])
->
T
|
None
:
non_none
=
[
item
for
item
in
items
if
item
is
not
None
]
if
len
(
non_none
)
==
0
:
return
None
combined
=
non_none
[
0
]
for
item
in
non_none
[
1
:]:
combined
=
f
(
combined
,
item
)
return
combined
@
dataclass
@
dataclass
class
KVConnectorOutput
:
class
KVConnectorOutput
:
# [req_ids]
# [req_ids]
...
@@ -146,6 +161,43 @@ class KVConnectorOutput:
...
@@ -146,6 +161,43 @@ class KVConnectorOutput:
and
not
self
.
invalid_block_ids
and
not
self
.
invalid_block_ids
)
)
@
classmethod
def
merge
(
cls
,
*
outputs
:
"KVConnectorOutput"
):
assert
len
(
outputs
)
>
0
,
"Cannot merge empty outputs"
finished_sending
=
_combine_non_none
(
set
.
union
,
[
output
.
finished_sending
for
output
in
outputs
]
)
finished_recving
=
_combine_non_none
(
set
.
union
,
[
output
.
finished_recving
for
output
in
outputs
]
)
kv_connector_stats
=
_combine_non_none
(
lambda
x
,
y
:
x
.
aggregate
(
y
),
[
output
.
kv_connector_stats
for
output
in
outputs
],
)
kv_cache_events
=
_combine_non_none
(
lambda
x
,
y
:
x
.
merge
(
y
),
[
output
.
kv_cache_events
for
output
in
outputs
],
)
invalid_block_ids
=
_combine_non_none
(
set
.
union
,
[
output
.
invalid_block_ids
for
output
in
outputs
]
)
assert
invalid_block_ids
is
not
None
assert
all
(
output
.
expected_finished_count
==
outputs
[
0
].
expected_finished_count
for
output
in
outputs
)
expected_finished_count
=
outputs
[
0
].
expected_finished_count
return
cls
(
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
kv_connector_stats
=
kv_connector_stats
,
kv_cache_events
=
kv_cache_events
,
invalid_block_ids
=
invalid_block_ids
,
expected_finished_count
=
expected_finished_count
,
)
@
dataclass
@
dataclass
class
ECConnectorOutput
:
class
ECConnectorOutput
:
...
...
vllm/v1/spec_decode/extract_hidden_states.py
0 → 100644
View file @
9433acb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
import
torch
import
torch.nn
as
nn
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer
import
has_kv_transfer_group
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
,
CommonAttentionMetadata
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
PADDING_SLOT_ID
=
-
1
class
ExtractHiddenStatesProposer
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
num_speculative_tokens
==
1
if
vllm_config
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self
.
vllm_config
=
vllm_config
self
.
device
=
device
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self
.
model
:
nn
.
Module
|
None
=
None
self
.
attn_layer_names
:
list
[
str
]
=
[]
self
.
attn_metadata_builder
:
AttentionMetadataBuilder
|
None
=
None
# Maximum number of tokens for buffers
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
self
.
hf_config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
layer_ids
=
getattr
(
self
.
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
None
)
if
not
layer_ids
:
raise
ValueError
(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self
.
num_hidden_states
=
len
(
layer_ids
)
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
self
.
hidden_states
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
num_hidden_states
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
,
)
self
.
cudagraph_dispatcher
=
CudagraphDispatcher
(
self
.
vllm_config
)
self
.
_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
def
propose
(
self
,
sampled_token_ids
:
torch
.
Tensor
,
target_hidden_states
:
list
[
torch
.
Tensor
],
common_attn_metadata
:
CommonAttentionMetadata
,
scheduler_output
:
SchedulerOutput
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
KVConnectorOutput
|
None
]:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
without performing actual attention computation. This allows us to
extract and store hidden states for later use (e.g., KV transfer).
This proposer doesn't actually perform speculation - it returns the
sampled tokens as "draft" tokens, ensuring they always verify (match).
The main purpose is to cache hidden states, not to speculate.
Args:
sampled_token_ids: Sampled token IDs from the target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
Returns:
Tuple of:
- Draft tokens matching sampled tokens, shape [batch_size, 1]
- KV connector output (if KV transfer is active), else None
"""
assert
self
.
model
is
not
None
and
isinstance
(
target_hidden_states
,
list
)
# target_hidden_states is a list of tensors (one per layer)
# Each tensor has shape [num_tokens, hidden_size]
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
stacked_hidden_states
=
torch
.
stack
(
target_hidden_states
,
dim
=
1
)
num_tokens
=
stacked_hidden_states
.
shape
[
0
]
# Copy hidden states to buffer
self
.
hidden_states
[:
num_tokens
]
=
stacked_hidden_states
assert
self
.
attn_metadata_builder
is
not
None
attn_metadata
=
self
.
attn_metadata_builder
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
0
)
# We assume all cache-only layers belong to the same KV cache group,
# thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
)
)
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
with
(
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
),
(
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
)
if
has_kv_transfer_group
()
else
nullcontext
()
)
as
kv_connector_output
,
):
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
)
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return
sampled_token_ids
.
unsqueeze
(
-
1
),
kv_connector_output
def
_get_slot_mapping
(
self
,
num_tokens
:
int
,
slot_mapping
:
torch
.
Tensor
|
None
=
None
,
)
->
dict
[
str
,
torch
.
Tensor
]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if
slot_mapping
is
not
None
:
num_actual
=
slot_mapping
.
shape
[
0
]
self
.
_slot_mapping_buffer
[:
num_actual
].
copy_
(
slot_mapping
)
if
num_tokens
>
num_actual
:
self
.
_slot_mapping_buffer
[
num_actual
:
num_tokens
].
fill_
(
PADDING_SLOT_ID
)
view
=
self
.
_slot_mapping_buffer
[:
num_tokens
]
return
{
name
:
view
for
name
in
self
.
attn_layer_names
}
def
_determine_batch_execution_and_padding
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
)
->
tuple
[
CUDAGraphMode
,
int
,
torch
.
Tensor
|
None
]:
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens
,
valid_modes
=
({
CUDAGraphMode
.
NONE
}
if
not
use_cudagraphs
else
None
),
)
num_tokens_padded
=
batch_desc
.
num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch
,
num_tokens_across_dp
=
False
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
num_tokens_padded
=
num_tokens_padded
,
cudagraph_mode
=
cudagraph_mode
.
value
,
)
)
assert
not
should_ubatch
,
(
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if
num_tokens_across_dp
is
not
None
:
dp_rank
=
self
.
dp_rank
num_tokens_padded
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_padded
,
valid_modes
=
{
CUDAGraphMode
(
synced_cudagraph_mode
)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert
batch_desc
.
num_tokens
==
num_tokens_padded
num_tokens_across_dp
[
dp_rank
]
=
num_tokens_padded
return
cudagraph_mode
,
num_tokens_padded
,
num_tokens_across_dp
def
initialize_cudagraph_keys
(
self
,
cudagraph_mode
:
CUDAGraphMode
)
->
None
:
"""Initialize cudagraph dispatcher keys.
Only supports PIECEWISE cudagraphs (via mixed_mode).
Should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
assert
self
.
vllm_config
.
speculative_config
is
not
None
if
(
not
self
.
vllm_config
.
speculative_config
.
enforce_eager
and
cudagraph_mode
.
mixed_mode
()
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
]
):
proposer_cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
else
:
proposer_cudagraph_mode
=
CUDAGraphMode
.
NONE
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
proposer_cudagraph_mode
)
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
is_graph_capturing
:
bool
=
False
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
)
->
None
:
assert
self
.
model
is
not
None
,
"Model must be initialized before dummy_run"
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
,
use_cudagraphs
=
use_cudagraphs
)
)
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
# Use our own slot mapping buffer during cudagraph capture.
if
(
self
.
attn_layer_names
and
slot_mappings
is
not
None
and
self
.
attn_layer_names
[
0
]
in
slot_mappings
):
slot_mapping_dict
=
self
.
_get_slot_mapping
(
num_input_tokens
)
else
:
slot_mapping_dict
=
slot_mappings
or
{}
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
slot_mapping_dict
,
):
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
)
def
_build_attn_metadata_builder
(
self
,
draft_attn_layers
:
dict
[
str
,
AttentionLayerBase
]
)
->
AttentionMetadataBuilder
:
"""Build the attention metadata builder from draft attention layers."""
if
not
draft_attn_layers
:
raise
ValueError
(
"No attention layers found for ExtractHiddenStatesModel"
)
layer
=
next
(
iter
(
draft_attn_layers
.
values
()))
attn_backend
=
layer
.
get_attn_backend
()
return
attn_backend
.
get_builder_cls
()(
layer
.
get_kv_cache_spec
(
self
.
vllm_config
),
self
.
attn_layer_names
,
self
.
vllm_config
,
self
.
device
,
)
def
prepare_next_token_ids_padded
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampled_token_ids
:
torch
.
Tensor
,
requests
:
dict
[
str
,
CachedRequestState
],
gpu_input_batch
:
InputBatch
,
discard_request_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Prepare next token IDs for speculative decoding.
Since num_speculative_tokens == 1, sampled_token_ids has shape
(batch_size, 1). For each request we either use the sampled token
(if valid and not discarded) or a backup token from the request state.
"""
num_reqs
=
gpu_input_batch
.
num_reqs
device
=
sampled_token_ids
.
device
# Compute backup tokens for discarded / invalid requests
backup_tokens_gpu
=
torch
.
tensor
(
[
requests
[
gpu_input_batch
.
req_ids
[
i
]].
get_token_id
(
common_attn_metadata
.
seq_lens_cpu
[
i
].
item
()
)
for
i
in
range
(
num_reqs
)
],
dtype
=
torch
.
int32
,
device
=
device
,
)
assert
discard_request_mask
.
dtype
==
torch
.
bool
# With num_speculative_tokens == 1, there is exactly one token
sampled
=
sampled_token_ids
[:,
0
]
is_valid
=
(
sampled
>=
0
)
&
(
sampled
<
gpu_input_batch
.
vocab_size
)
valid_sampled_tokens_count
=
is_valid
.
to
(
torch
.
int32
)
use_sampled
=
is_valid
&
~
discard_request_mask
[:
num_reqs
]
next_token_ids
=
torch
.
where
(
use_sampled
,
sampled
.
to
(
torch
.
int32
),
backup_tokens_gpu
)
return
next_token_ids
,
valid_sampled_tokens_count
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
"""Load the ExtractHiddenStatesModel model.
This method instantiates the ExtractHiddenStatesModel model which is used
to cache hidden states during speculative decoding. The model uses
cache-only attention (no computation, just caching KV states).
Args:
target_model: The target model (passed for compatibility with
EagleProposer interface, but not used here)
"""
# Get the target model's attention layers before loading draft model
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
# type: ignore[type-abstract]
)
assert
self
.
vllm_config
.
speculative_config
is
not
None
draft_model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"extract_hidden_states"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
draft_model_config
)
# Identify draft model's attention layers (difference from target)
all_attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
# type: ignore[type-abstract]
)
draft_attn_layers
=
{
name
:
layer
for
name
,
layer
in
all_attn_layers
.
items
()
if
name
not
in
target_attn_layer_names
}
self
.
attn_layer_names
=
list
(
draft_attn_layers
.
keys
())
assert
len
(
draft_attn_layers
)
==
1
,
(
"ExtractHiddenStatesModel should have exactly one "
f
"attention layer, found
{
len
(
draft_attn_layers
)
}
"
)
self
.
attn_metadata_builder
=
self
.
_build_attn_metadata_builder
(
draft_attn_layers
)
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert
len
(
self
.
attn_layer_names
)
==
1
vllm/v1/worker/gpu_model_runner.py
View file @
9433acb8
...
@@ -159,6 +159,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
...
@@ -159,6 +159,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
...
@@ -495,6 +496,7 @@ class GPUModelRunner(
...
@@ -495,6 +496,7 @@ class GPUModelRunner(
|
EagleProposer
|
EagleProposer
|
DraftModelProposer
|
DraftModelProposer
|
MedusaProposer
|
MedusaProposer
|
ExtractHiddenStatesProposer
)
)
if
self
.
speculative_config
.
method
==
"ngram"
:
if
self
.
speculative_config
.
method
==
"ngram"
:
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
...
@@ -518,6 +520,11 @@ class GPUModelRunner(
...
@@ -518,6 +520,11 @@ class GPUModelRunner(
self
.
drafter
=
MedusaProposer
(
self
.
drafter
=
MedusaProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
)
)
elif
self
.
speculative_config
.
method
==
"extract_hidden_states"
:
self
.
drafter
=
ExtractHiddenStatesProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
)
self
.
use_aux_hidden_state_outputs
=
True
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Unknown speculative decoding method: "
"Unknown speculative decoding method: "
...
@@ -3693,10 +3700,9 @@ class GPUModelRunner(
...
@@ -3693,10 +3700,9 @@ class GPUModelRunner(
def
sample_tokens
(
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput | None"
self
,
grammar_output
:
"GrammarOutput | None"
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
if
self
.
execute_model_state
is
None
:
kv_connector_output
=
self
.
kv_connector_output
kv_connector_output
=
self
.
kv_connector_output
self
.
kv_connector_output
=
None
self
.
kv_connector_output
=
None
if
self
.
execute_model_state
is
None
:
# receive sampled token ids from the last PP rank.
# receive sampled token ids from the last PP rank.
if
self
.
use_async_scheduling
and
get_pp_group
().
world_size
>
1
:
if
self
.
use_async_scheduling
and
get_pp_group
().
world_size
>
1
:
self
.
_pp_receive_prev_sampled_token_ids_to_input_batch
()
self
.
_pp_receive_prev_sampled_token_ids_to_input_batch
()
...
@@ -3778,12 +3784,17 @@ class GPUModelRunner(
...
@@ -3778,12 +3784,17 @@ class GPUModelRunner(
<=
self
.
effective_drafter_max_model_len
<=
self
.
effective_drafter_max_model_len
)
)
use_gpu_toks
=
(
use_gpu_toks
=
(
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
()
or
spec_config
.
uses_extract_hidden_states
()
)
and
not
spec_config
.
disable_padded_drafter_batch
)
and
not
spec_config
.
disable_padded_drafter_batch
if
use_gpu_toks
:
if
use_gpu_toks
:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
# as inputs, and does not need to wait for bookkeeping to finish.
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
if
input_fits_in_drafter
:
propose_draft_token_ids
(
sampled_token_ids
)
propose_draft_token_ids
(
sampled_token_ids
)
...
@@ -3842,6 +3853,10 @@ class GPUModelRunner(
...
@@ -3842,6 +3853,10 @@ class GPUModelRunner(
with
record_function_or_nullcontext
(
"gpu_model_runner: eplb"
):
with
record_function_or_nullcontext
(
"gpu_model_runner: eplb"
):
self
.
eplb_step
()
self
.
eplb_step
()
# self.kv_connector_output may be modified during drafting
kv_connector_output
=
self
.
kv_connector_output
self
.
kv_connector_output
=
None
with
record_function_or_nullcontext
(
"gpu_model_runner: ModelRunnerOutput"
):
with
record_function_or_nullcontext
(
"gpu_model_runner: ModelRunnerOutput"
):
if
self
.
model_config
.
enable_return_routed_experts
:
if
self
.
model_config
.
enable_return_routed_experts
:
capturer
=
RoutedExpertsCapturer
.
get_instance
()
capturer
=
RoutedExpertsCapturer
.
get_instance
()
...
@@ -4068,6 +4083,48 @@ class GPUModelRunner(
...
@@ -4068,6 +4083,48 @@ class GPUModelRunner(
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
)
)
elif
spec_config
.
uses_extract_hidden_states
():
assert
isinstance
(
self
.
drafter
,
ExtractHiddenStatesProposer
)
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
(
"sampled_token_ids should be a torch.Tensor for "
"extract_hidden_states method."
)
if
not
self
.
use_aux_hidden_state_outputs
or
aux_hidden_states
is
None
:
raise
ValueError
(
"aux_hidden_states are required when using `extract_hidden_states`"
)
target_hidden_states
=
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
]
draft_token_ids
,
drafter_kv_connector_output
=
self
.
drafter
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
scheduler_output
,
slot_mappings
=
slot_mappings
,
)
# Combine KVConnectorOutputs or select the non-empty one
if
self
.
kv_connector_output
and
drafter_kv_connector_output
:
self
.
kv_connector_output
=
KVConnectorOutput
.
merge
(
self
.
kv_connector_output
,
drafter_kv_connector_output
)
else
:
self
.
kv_connector_output
=
(
self
.
kv_connector_output
or
drafter_kv_connector_output
)
next_token_ids
,
valid_sampled_tokens_count
=
(
self
.
drafter
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
sampled_token_ids
,
self
.
requests
,
self
.
input_batch
,
self
.
discard_request_mask
.
gpu
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
():
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
...
@@ -4946,8 +5003,12 @@ class GPUModelRunner(
...
@@ -4946,8 +5003,12 @@ class GPUModelRunner(
if
self
.
speculative_config
and
(
if
self
.
speculative_config
and
(
self
.
speculative_config
.
use_eagle
()
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
or
self
.
speculative_config
.
uses_draft_model
()
or
self
.
speculative_config
.
uses_extract_hidden_states
()
):
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
assert
self
.
speculative_config
is
not
None
assert
self
.
speculative_config
is
not
None
# Eagle currently only supports PIECEWISE cudagraphs.
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
# Therefore only use cudagraphs if the main model uses PIECEWISE
...
@@ -5656,9 +5717,12 @@ class GPUModelRunner(
...
@@ -5656,9 +5717,12 @@ class GPUModelRunner(
cudagraph_mode
,
self
.
uniform_decode_query_len
cudagraph_mode
,
self
.
uniform_decode_query_len
)
)
# Initialize eagle's cudagraph dispatcher if using eagle spec decode.
# Initialize drafter's cudagraph dispatcher if using spec decode.
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
(
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_extract_hidden_states
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
ExtractHiddenStatesProposer
)
self
.
drafter
.
initialize_cudagraph_keys
(
cudagraph_mode
)
self
.
drafter
.
initialize_cudagraph_keys
(
cudagraph_mode
)
def
calculate_reorder_batch_threshold
(
self
)
->
None
:
def
calculate_reorder_batch_threshold
(
self
)
->
None
:
...
@@ -6025,8 +6089,12 @@ class GPUModelRunner(
...
@@ -6025,8 +6089,12 @@ class GPUModelRunner(
if
self
.
speculative_config
and
(
if
self
.
speculative_config
and
(
self
.
speculative_config
.
use_eagle
()
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
or
self
.
speculative_config
.
uses_draft_model
()
or
self
.
speculative_config
.
uses_extract_hidden_states
()
):
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
# validate all draft model layers belong to the same kv cache
# validate all draft model layers belong to the same kv cache
# group
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment