Unverified commit c59a132f, authored by Wentao Ye, committed by GitHub
Browse files

[V0 Deprecation] Refactor kv cache from list to element (#37487)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent de99d91e
......@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......
......@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
......
......@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......@@ -1046,7 +1046,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......
......@@ -481,13 +481,9 @@ class AiterFlashAttentionMetadataBuilder(
):
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0]
kv_cache_shape = (
self.vllm_config.compilation_config.static_forward_context[
kv_cache_shape = self.vllm_config.compilation_config.static_forward_context[
first_layer_name
]
.kv_cache[0]
.shape
)
].kv_cache.shape
num_blocks = kv_cache_shape[1]
self.scale = torch.ones(
[num_blocks, self.num_heads_kv, self.block_size],
......
......@@ -5830,7 +5830,10 @@ class GPUModelRunner(
for layer in self.compilation_config.static_forward_context.values():
if hasattr(layer, "kv_cache"):
layer.kv_cache = []
kv_cache = layer.kv_cache
layer.kv_cache = (
torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else []
)
gc.collect()
torch.accelerator.empty_cache()
......
......@@ -119,7 +119,7 @@ def collect_mamba_copy_meta(
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
for layer_name in layer_names:
attention = forward_context[layer_name]
kv_caches: list[torch.Tensor] = attention.kv_cache[0]
kv_caches: list[torch.Tensor] = attention.kv_cache
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
copy_spec = state_copy_func(
state, block_ids, src_block_idx, accept_token_bias + 1
......
......@@ -136,8 +136,8 @@ class KVBlockZeroer:
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
kv = static_forward_context[layer_name].kv_cache
if not isinstance(kv, torch.Tensor):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
......@@ -510,8 +510,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Keep list wrapper for layers that index kv_cache by engine slot.
forward_context[layer_name].kv_cache = [kv_cache]
forward_context[layer_name].kv_cache = kv_cache
def is_residual_scattered_for_sp(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment