Unverified commit 04bf5a35, authored by Fynn Schmitt-Ulms and committed by GitHub
Browse files

[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)

parent 43a73f85
...@@ -252,29 +252,22 @@ def test_propose(): ...@@ -252,29 +252,22 @@ def test_propose():
] ]
# Sampled token IDs from target model # Sampled token IDs from target model
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device) sampled_token_ids = torch.tensor(
[42, 60], dtype=torch.int32, device=device
# Mock scheduler output ).unsqueeze(-1)
mock_scheduler_output = mock.MagicMock()
# Call propose # Call propose
with mock.patch( draft_tokens = proposer.propose(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
) as mock_has_kv:
mock_has_kv.return_value = False
draft_tokens, kv_connector_output = proposer.propose(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
scheduler_output=mock_scheduler_output,
slot_mappings=None, slot_mappings=None,
) )
# Verify draft tokens match sampled tokens # Verify draft tokens match sampled tokens
# Shape should be [batch_size, 1] for num_speculative_tokens=1 # Shape should be [batch_size, 1] for num_speculative_tokens=1
assert draft_tokens.shape == (batch_size, 1) assert draft_tokens.shape == (batch_size, 1)
assert torch.equal(draft_tokens[:, 0], sampled_token_ids) assert torch.equal(draft_tokens, sampled_token_ids)
# Verify the model was called # Verify the model was called
model_mock.assert_called_once() model_mock.assert_called_once()
...@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers): ...@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
for _ in range(num_hidden_layers) for _ in range(num_hidden_layers)
] ]
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device) sampled_token_ids = torch.tensor(
mock_scheduler_output = mock.MagicMock() [42, 60], dtype=torch.int32, device=device
).unsqueeze(-1)
with mock.patch(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
) as mock_has_kv:
mock_has_kv.return_value = False
draft_tokens, _ = proposer.propose( draft_tokens = proposer.propose(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
scheduler_output=mock_scheduler_output,
slot_mappings=None, slot_mappings=None,
) )
assert draft_tokens.shape == (batch_size, 1) assert draft_tokens.shape == (batch_size, 1)
assert torch.equal(draft_tokens[:, 0], sampled_token_ids) assert torch.equal(draft_tokens, sampled_token_ids)
...@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1): ...@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
cached_req = self._active_requests[req_id] cached_req = self._active_requests[req_id]
req_block_ids = self._req_blocks[req_id] req_block_ids = self._req_blocks[req_id]
assert new_block_ids is not None if new_block_ids is None:
continue
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
req_block_ids.extend(block_ids) req_block_ids.extend(block_ids)
......
...@@ -3,26 +3,21 @@ ...@@ -3,26 +3,21 @@
from __future__ import annotations from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1 PADDING_SLOT_ID = -1
...@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer: ...@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
target_hidden_states: list[torch.Tensor], target_hidden_states: list[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
scheduler_output: SchedulerOutput,
slot_mappings: dict[str, torch.Tensor] slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]] | list[dict[str, torch.Tensor]]
| None = None, | None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]: ) -> torch.Tensor:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model. """Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache The ExtractHiddenStatesModel caches the hidden states in the KV cache
...@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer: ...@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer:
target_hidden_states: List of hidden state tensors from target model target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer) (one per aux hidden state layer)
common_attn_metadata: Attention metadata common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility) interface compatibility)
...@@ -136,8 +129,7 @@ class ExtractHiddenStatesProposer: ...@@ -136,8 +129,7 @@ class ExtractHiddenStatesProposer:
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
with ( with set_forward_context(
set_forward_context(
per_layer_attn_metadata, per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
...@@ -146,12 +138,6 @@ class ExtractHiddenStatesProposer: ...@@ -146,12 +138,6 @@ class ExtractHiddenStatesProposer:
slot_mapping=self._get_slot_mapping( slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping num_input_tokens, common_attn_metadata.slot_mapping
), ),
),
(
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
) as kv_connector_output,
): ):
self.model( self.model(
hidden_states=self.hidden_states[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens],
...@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer: ...@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
# Return the sampled tokens as "draft" tokens # Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1 # Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids.unsqueeze(-1), kv_connector_output return sampled_token_ids
def _get_slot_mapping( def _get_slot_mapping(
self, self,
......
...@@ -4328,23 +4328,12 @@ class GPUModelRunner( ...@@ -4328,23 +4328,12 @@ class GPUModelRunner(
) )
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
draft_token_ids, drafter_kv_connector_output = self.drafter.propose( draft_token_ids = self.drafter.propose(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
scheduler_output=scheduler_output,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
) )
# Combine KVConnectorOutputs or select the non-empty one
if self.kv_connector_output and drafter_kv_connector_output:
self.kv_connector_output = KVConnectorOutput.merge(
self.kv_connector_output, drafter_kv_connector_output
)
else:
self.kv_connector_output = (
self.kv_connector_output or drafter_kv_connector_output
)
next_token_ids, valid_sampled_tokens_count = ( next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded( self.drafter.prepare_next_token_ids_padded(
common_attn_metadata, common_attn_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment