Unverified Commit de3869bb authored by Rohan Potdar's avatar Rohan Potdar Committed by GitHub
Browse files

move checks out of `unified_kv_cache_update` custom op (#33943)


Signed-off-by: default avatarRohan138 <rohanpotdar138@gmail.com>
parent ce9b3cd3
......@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update:
if self.use_direct_call:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
......@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
else:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update and (
# torch can only dispatch custom op if a tensor is passed
key is not None or value is not None
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
......
......@@ -136,6 +136,9 @@ def create_cross_attention_backend(
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
......
......@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
......
......@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
# we use direct Q, K, V tensors without caching
return
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is not None
or key is None
or value is None
):
return
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
......
......@@ -196,16 +196,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
):
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
......
......@@ -383,17 +383,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
......
......@@ -579,13 +579,7 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment