Unverified Commit c59a132f authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[V0 Deprecation] Refactor kv cache from list to element (#37487)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent de99d91e
......@@ -127,7 +127,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
self.attn.kv_cache = kv_cache
# Build attn metadata
self.attn_metadata = self.builder.build(
......
......@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
self.attn.kv_cache = kv_cache
# Build attn metadata
attn_metadata = self.builder.build(
......@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
}
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[0]
kv_cache_unfused = attn_layer.kv_cache
del dummy
torch._dynamo.mark_dynamic(qkv, 0)
......@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
}
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[0]
kv_cache_fused = attn_layer.kv_cache
del dummy
assert fusion_pass.matched_count == 1
......
......@@ -258,8 +258,8 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
mamba_kv_cache_dict[
num_computed_tokens - num_computed_tokens % BLOCK_SIZE
] = (
kv_cache[0][0][block_id].clone(),
kv_cache[0][1][block_id].clone(),
kv_cache[0][block_id].clone(),
kv_cache[1][block_id].clone(),
)
last_num_computed_tokens = num_computed_tokens
......@@ -302,7 +302,7 @@ def get_fake_process_mamba_fn(
mamba_layer_name = kv_cache_config.kv_cache_groups[
mamba_group_id
].layer_names[0]
mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1]
mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[-1]
mamba_block_table = input_batch.block_table.block_tables[
mamba_group_id
].block_table.cpu[0]
......
......@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
layer_0_kv = vllm_ctx[layer_0].kv_cache
layer_1_kv = vllm_ctx[layer_1].kv_cache
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)
......@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
runner.initialize_kv_cache(kv_cache_config)
kv_cache_config_after_init = runner.kv_cache_config
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
layer_0_kv = vllm_ctx[layer_0].kv_cache
layer_1_kv = vllm_ctx[layer_1].kv_cache
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)
......@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
attn_shape = vllm_ctx[layer_0].kv_cache.shape
conv_shape = vllm_ctx[layer_2].kv_cache[0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[1].shape
# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
......@@ -905,21 +905,21 @@ def test_hybrid_attention_mamba_tensor_shapes():
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
# attention: kv_cache[kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
vllm_ctx[layer].kv_cache[kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
# mamba: kv_cache[component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[1][kv_block, :] = ssm_blocks_constant[i]
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
actual_kv = vllm_ctx[layer].kv_cache[kernel_block, :]
expected = attn_blocks_constant[i]
# Check K and V separately
......@@ -928,8 +928,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
......@@ -938,8 +938,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
......
......@@ -23,10 +23,10 @@ def test_bind_kv_cache(default_vllm_config):
}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"]
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"]
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"]
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"]
assert ctx["layers.0.self_attn"].kv_cache is kv_cache["layers.0.self_attn"]
assert ctx["layers.1.self_attn"].kv_cache is kv_cache["layers.1.self_attn"]
assert ctx["layers.2.self_attn"].kv_cache is kv_cache["layers.2.self_attn"]
assert ctx["layers.3.self_attn"].kv_cache is kv_cache["layers.3.self_attn"]
assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
......@@ -50,8 +50,8 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"]
assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"]
assert ctx["model.layers.20.attn"].kv_cache is kv_cache["model.layers.20.attn"]
assert ctx["model.layers.28.attn"].kv_cache is kv_cache["model.layers.28.attn"]
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
......@@ -74,14 +74,14 @@ def test_bind_kv_cache_draft_model(default_vllm_config):
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
assert ctx["model.layers.0.attn"].kv_cache is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache is kv_cache["model.layers.1.attn"]
assert (
ctx["draft_model.layers.0.attn"].kv_cache[0]
ctx["draft_model.layers.0.attn"].kv_cache
is kv_cache["draft_model.layers.0.attn"]
)
assert (
ctx["draft_model.layers.1.attn"].kv_cache[0]
ctx["draft_model.layers.1.attn"].kv_cache
is kv_cache["draft_model.layers.1.attn"]
)
......
......@@ -181,12 +181,10 @@ class ExampleConnector(KVConnectorBase_V1):
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
kv_cache_layer = getattr(layer, "kv_cache", None)
if kv_cache_layer is None:
continue
kv_cache_layer = kv_cache_attr[0]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
......
......@@ -778,7 +778,7 @@ class LMCacheConnectorV1Impl:
continue
if layer_name not in self.kv_caches:
self.kv_caches[layer_name] = attn_layer.kv_cache[0]
self.kv_caches[layer_name] = attn_layer.kv_cache
####################
# Worker side APIs
......
......@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
if kv_cache is None:
continue
layer = kv_cache[0]
layer = kv_cache
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address
......
......@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase):
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
self.kv_cache = torch.tensor([])
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
......@@ -599,7 +596,7 @@ def get_attention_context(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
......
......@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.kv_cache = torch.tensor([])
self.use_sparse = use_sparse
......@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
......@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update(
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
......
......@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
"sink_key and sink_value have not been prepared"
)
if not self.sink_populated:
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
return super().forward(query, key, value, output_shape)
......
......@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens
constant_caches = self.kv_cache[0]
constant_caches = self.kv_cache
q_proj_states = q_proj_states[:num_actual_tokens]
k_proj_states = k_proj_states[:num_actual_tokens]
......
......@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None:
kv_cache = self.kv_cache[0][0]
kv_cache = self.kv_cache[0]
state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
......
......@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
......
......@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
......
......@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
......@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp):
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
self.k_cache.kv_cache,
q_fp8,
k,
weights,
......@@ -389,7 +389,7 @@ class SparseAttnIndexer(CustomOp):
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
self.k_cache.kv_cache,
q_fp8,
k,
weights,
......
......@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
# Get KV cache and state indices
if attn_metadata is not None:
kv_cache = self.kv_cache[0][0]
kv_cache = self.kv_cache[0]
state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
......
......@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
):
super().__init__()
self.kv_cache = [torch.tensor([])]
self.kv_cache = torch.tensor([])
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config
......
......@@ -51,7 +51,7 @@ def unified_kv_cache_update(
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
......@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
)
# Placeholder KV cache (replaced by bind_kv_cache)
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
self.kv_cache = torch.tensor([])
# Register in compilation context
compilation_config = vllm_config.compilation_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment