Unverified Commit c59a132f authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[V0 Deprecation] Refactor kv cache from list to element (#37487)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent de99d91e
...@@ -127,7 +127,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -127,7 +127,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
raw_tensor = raw_tensor.view(kv_cache_shape) raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order) kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache] self.attn.kv_cache = kv_cache
# Build attn metadata # Build attn metadata
self.attn_metadata = self.builder.build( self.attn_metadata = self.builder.build(
......
...@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module): ...@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
raw_tensor = raw_tensor.view(kv_cache_shape) raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order) kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache] self.attn.kv_cache = kv_cache
# Build attn metadata # Build attn metadata
attn_metadata = self.builder.build( attn_metadata = self.builder.build(
...@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion( ...@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
} }
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused) q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name] attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[0] kv_cache_unfused = attn_layer.kv_cache
del dummy del dummy
torch._dynamo.mark_dynamic(qkv, 0) torch._dynamo.mark_dynamic(qkv, 0)
...@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion( ...@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
} }
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos) q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name] attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[0] kv_cache_fused = attn_layer.kv_cache
del dummy del dummy
assert fusion_pass.matched_count == 1 assert fusion_pass.matched_count == 1
......
...@@ -258,8 +258,8 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable): ...@@ -258,8 +258,8 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
mamba_kv_cache_dict[ mamba_kv_cache_dict[
num_computed_tokens - num_computed_tokens % BLOCK_SIZE num_computed_tokens - num_computed_tokens % BLOCK_SIZE
] = ( ] = (
kv_cache[0][0][block_id].clone(), kv_cache[0][block_id].clone(),
kv_cache[0][1][block_id].clone(), kv_cache[1][block_id].clone(),
) )
last_num_computed_tokens = num_computed_tokens last_num_computed_tokens = num_computed_tokens
...@@ -302,7 +302,7 @@ def get_fake_process_mamba_fn( ...@@ -302,7 +302,7 @@ def get_fake_process_mamba_fn(
mamba_layer_name = kv_cache_config.kv_cache_groups[ mamba_layer_name = kv_cache_config.kv_cache_groups[
mamba_group_id mamba_group_id
].layer_names[0] ].layer_names[0]
mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1] mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[-1]
mamba_block_table = input_batch.block_table.block_tables[ mamba_block_table = input_batch.block_table.block_tables[
mamba_group_id mamba_group_id
].block_table.cpu[0] ].block_table.cpu[0]
......
...@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config): ...@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
runner.initialize_kv_cache(kv_cache_config) runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_0_kv = vllm_ctx[layer_0].kv_cache
layer_1_kv = vllm_ctx[layer_1].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache
# check layer 1 kv cache does NOT share memory with layer 0 # check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv) assert id(layer_1_kv) != id(layer_0_kv)
...@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config): ...@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
runner.initialize_kv_cache(kv_cache_config) runner.initialize_kv_cache(kv_cache_config)
kv_cache_config_after_init = runner.kv_cache_config kv_cache_config_after_init = runner.kv_cache_config
layer_0_kv = vllm_ctx[layer_0].kv_cache[0] layer_0_kv = vllm_ctx[layer_0].kv_cache
layer_1_kv = vllm_ctx[layer_1].kv_cache[0] layer_1_kv = vllm_ctx[layer_1].kv_cache
# check layer 1 kv cache shares memory with layer 0 # check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv) assert id(layer_1_kv) == id(layer_0_kv)
...@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
np.random.shuffle(ind) np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape attn_shape = vllm_ctx[layer_0].kv_cache.shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape conv_shape = vllm_ctx[layer_2].kv_cache[0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape ssm_shape = vllm_ctx[layer_2].kv_cache[1].shape
# assert we are using FlashInfer # assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0 assert attn_shape[0] % num_blocks == 0
...@@ -905,21 +905,21 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -905,21 +905,21 @@ def test_hybrid_attention_mamba_tensor_shapes():
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
for layer in [layer_0, layer_1]: for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...] # attention: kv_cache[kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention): for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] vllm_ctx[layer].kv_cache[kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices # fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...] # mamba: kv_cache[component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] vllm_ctx[layer].kv_cache[0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] vllm_ctx[layer].kv_cache[1][kv_block, :] = ssm_blocks_constant[i]
# verify attention and mamba contents are correct # verify attention and mamba contents are correct
for layer in [layer_0, layer_1]: for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention): for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] actual_kv = vllm_ctx[layer].kv_cache[kernel_block, :]
expected = attn_blocks_constant[i] expected = attn_blocks_constant[i]
# Check K and V separately # Check K and V separately
...@@ -928,8 +928,8 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -928,8 +928,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
expected_conv = conv_blocks_constant[i] expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i] expected_ssm = ssm_blocks_constant[i]
...@@ -938,8 +938,8 @@ def test_hybrid_attention_mamba_tensor_shapes(): ...@@ -938,8 +938,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
expected_conv = conv_blocks_constant[i] expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i] expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv) assert torch.equal(actual_conv, expected_conv)
......
...@@ -23,10 +23,10 @@ def test_bind_kv_cache(default_vllm_config): ...@@ -23,10 +23,10 @@ def test_bind_kv_cache(default_vllm_config):
} }
runner_kv_caches: list[torch.Tensor] = [] runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches) bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"] assert ctx["layers.0.self_attn"].kv_cache is kv_cache["layers.0.self_attn"]
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"] assert ctx["layers.1.self_attn"].kv_cache is kv_cache["layers.1.self_attn"]
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"] assert ctx["layers.2.self_attn"].kv_cache is kv_cache["layers.2.self_attn"]
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"] assert ctx["layers.3.self_attn"].kv_cache is kv_cache["layers.3.self_attn"]
assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"] assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"] assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
...@@ -50,8 +50,8 @@ def test_bind_kv_cache_non_attention(default_vllm_config): ...@@ -50,8 +50,8 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
runner_kv_caches: list[torch.Tensor] = [] runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches) bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"] assert ctx["model.layers.20.attn"].kv_cache is kv_cache["model.layers.20.attn"]
assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"] assert ctx["model.layers.28.attn"].kv_cache is kv_cache["model.layers.28.attn"]
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
...@@ -74,14 +74,14 @@ def test_bind_kv_cache_draft_model(default_vllm_config): ...@@ -74,14 +74,14 @@ def test_bind_kv_cache_draft_model(default_vllm_config):
runner_kv_caches: list[torch.Tensor] = [] runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches) bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] assert ctx["model.layers.0.attn"].kv_cache is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] assert ctx["model.layers.1.attn"].kv_cache is kv_cache["model.layers.1.attn"]
assert ( assert (
ctx["draft_model.layers.0.attn"].kv_cache[0] ctx["draft_model.layers.0.attn"].kv_cache
is kv_cache["draft_model.layers.0.attn"] is kv_cache["draft_model.layers.0.attn"]
) )
assert ( assert (
ctx["draft_model.layers.1.attn"].kv_cache[0] ctx["draft_model.layers.1.attn"].kv_cache
is kv_cache["draft_model.layers.1.attn"] is kv_cache["draft_model.layers.1.attn"]
) )
......
...@@ -181,12 +181,10 @@ class ExampleConnector(KVConnectorBase_V1): ...@@ -181,12 +181,10 @@ class ExampleConnector(KVConnectorBase_V1):
# Only process layers that have kv_cache # Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention # attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc. # layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None) kv_cache_layer = getattr(layer, "kv_cache", None)
if kv_cache_attr is None: if kv_cache_layer is None:
continue continue
kv_cache_layer = kv_cache_attr[0]
filename = self._generate_filename_debug( filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes layer_name, request.token_ids, request.mm_hashes
) )
......
...@@ -778,7 +778,7 @@ class LMCacheConnectorV1Impl: ...@@ -778,7 +778,7 @@ class LMCacheConnectorV1Impl:
continue continue
if layer_name not in self.kv_caches: if layer_name not in self.kv_caches:
self.kv_caches[layer_name] = attn_layer.kv_cache[0] self.kv_caches[layer_name] = attn_layer.kv_cache
#################### ####################
# Worker side APIs # Worker side APIs
......
...@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
if kv_cache is None: if kv_cache is None:
continue continue
layer = kv_cache[0] layer = kv_cache
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address request.request_id + "#" + layer_name, remote_address
......
...@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase):
# use a placeholder kv cache tensor during init, which will be replaced # use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache # by bind_kv_cache
# this variable will not be accessed if use_direct_call is True # this variable will not be accessed if use_direct_call is True
self.kv_cache = [ self.kv_cache = torch.tensor([])
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize KV cache quantization attributes # Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix) _init_kv_cache_quant(self, quant_config, prefix)
...@@ -599,7 +596,7 @@ def get_attention_context( ...@@ -599,7 +596,7 @@ def get_attention_context(
if isinstance(attn_metadata, dict): if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0] kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
......
...@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
self.kv_cache = [ self.kv_cache = torch.tensor([])
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.use_sparse = use_sparse self.use_sparse = use_sparse
...@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict): if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name] attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
...@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update( ...@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update(
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0] kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
......
...@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
"sink_key and sink_value have not been prepared" "sink_key and sink_value have not been prepared"
) )
if not self.sink_populated: if not self.sink_populated:
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
return super().forward(query, key, value, output_shape) return super().forward(query, key, value, output_shape)
......
...@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
constant_caches = self.kv_cache[0] constant_caches = self.kv_cache
q_proj_states = q_proj_states[:num_actual_tokens] q_proj_states = q_proj_states[:num_actual_tokens]
k_proj_states = k_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens]
......
...@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None: if attn_metadata is not None:
kv_cache = self.kv_cache[0][0] kv_cache = self.kv_cache[0]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences( clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata kv_cache, state_indices_tensor, attn_metadata
......
...@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
......
...@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
......
...@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp): ...@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata) assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
...@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp):
return torch.ops.vllm.sparse_attn_indexer( return torch.ops.vllm.sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
self.k_cache.kv_cache[0], self.k_cache.kv_cache,
q_fp8, q_fp8,
k, k,
weights, weights,
...@@ -389,7 +389,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -389,7 +389,7 @@ class SparseAttnIndexer(CustomOp):
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
self.k_cache.kv_cache[0], self.k_cache.kv_cache,
q_fp8, q_fp8,
k, k,
weights, weights,
......
...@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase): ...@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
# Get KV cache and state indices # Get KV cache and state indices
if attn_metadata is not None: if attn_metadata is not None:
kv_cache = self.kv_cache[0][0] kv_cache = self.kv_cache[0]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences( clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata kv_cache, state_indices_tensor, attn_metadata
......
...@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): ...@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
): ):
super().__init__() super().__init__()
self.kv_cache = [torch.tensor([])] self.kv_cache = torch.tensor([])
self.head_dim = head_dim self.head_dim = head_dim
self.prefix = prefix self.prefix = prefix
self.cache_config = cache_config self.cache_config = cache_config
......
...@@ -51,7 +51,7 @@ def unified_kv_cache_update( ...@@ -51,7 +51,7 @@ def unified_kv_cache_update(
""" """
forward_context = get_forward_context() forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0] kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
...@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase): ...@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
) )
# Placeholder KV cache (replaced by bind_kv_cache) # Placeholder KV cache (replaced by bind_kv_cache)
self.kv_cache = [ self.kv_cache = torch.tensor([])
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Register in compilation context # Register in compilation context
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment