Unverified Commit c59a132f authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[V0 Deprecation] Refactor kv cache from list to element (#37487)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent de99d91e
...@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): ...@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
......
...@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
......
...@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
non_spec_token_indx = attn_metadata.non_spec_token_indx non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
...@@ -1046,7 +1046,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -1046,7 +1046,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
Core attention computation with a packed non-spec decode fast path. Core attention computation with a packed non-spec decode fast path.
""" """
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[0] self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
......
...@@ -481,13 +481,9 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -481,13 +481,9 @@ class AiterFlashAttentionMetadataBuilder(
): ):
layers = get_layers_from_vllm_config(self.vllm_config, Attention) layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0] first_layer_name = [k for k in layers][0]
kv_cache_shape = ( kv_cache_shape = self.vllm_config.compilation_config.static_forward_context[
self.vllm_config.compilation_config.static_forward_context[
first_layer_name first_layer_name
] ].kv_cache.shape
.kv_cache[0]
.shape
)
num_blocks = kv_cache_shape[1] num_blocks = kv_cache_shape[1]
self.scale = torch.ones( self.scale = torch.ones(
[num_blocks, self.num_heads_kv, self.block_size], [num_blocks, self.num_heads_kv, self.block_size],
......
...@@ -5830,7 +5830,10 @@ class GPUModelRunner( ...@@ -5830,7 +5830,10 @@ class GPUModelRunner(
for layer in self.compilation_config.static_forward_context.values(): for layer in self.compilation_config.static_forward_context.values():
if hasattr(layer, "kv_cache"): if hasattr(layer, "kv_cache"):
layer.kv_cache = [] kv_cache = layer.kv_cache
layer.kv_cache = (
torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else []
)
gc.collect() gc.collect()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
......
...@@ -119,7 +119,7 @@ def collect_mamba_copy_meta( ...@@ -119,7 +119,7 @@ def collect_mamba_copy_meta(
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
for layer_name in layer_names: for layer_name in layer_names:
attention = forward_context[layer_name] attention = forward_context[layer_name]
kv_caches: list[torch.Tensor] = attention.kv_cache[0] kv_caches: list[torch.Tensor] = attention.kv_cache
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs): for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
copy_spec = state_copy_func( copy_spec = state_copy_func(
state, block_ids, src_block_idx, accept_token_bias + 1 state, block_ids, src_block_idx, accept_token_bias + 1
......
...@@ -136,8 +136,8 @@ class KVBlockZeroer: ...@@ -136,8 +136,8 @@ class KVBlockZeroer:
for layer_name in group.layer_names: for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers: if layer_name in runner_only_attn_layers:
continue continue
kv = static_forward_context[layer_name].kv_cache[0] kv = static_forward_context[layer_name].kv_cache
if isinstance(kv, list): if not isinstance(kv, torch.Tensor):
continue continue
dp = kv.data_ptr() dp = kv.data_ptr()
if dp in seen_ptrs: if dp in seen_ptrs:
...@@ -510,8 +510,7 @@ def bind_kv_cache( ...@@ -510,8 +510,7 @@ def bind_kv_cache(
# Bind kv_caches to forward context # Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items(): for layer_name, kv_cache in kv_caches.items():
# NOTE: Keep list wrapper for layers that index kv_cache by engine slot. forward_context[layer_name].kv_cache = kv_cache
forward_context[layer_name].kv_cache = [kv_cache]
def is_residual_scattered_for_sp( def is_residual_scattered_for_sp(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment