Unverified Commit de3869bb authored by Rohan Potdar's avatar Rohan Potdar Committed by GitHub
Browse files

move checks out of `unified_kv_cache_update` custom op (#33943)


Signed-off-by: default avatarRohan138 <rohanpotdar138@gmail.com>
parent ce9b3cd3
...@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None: if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v) value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call: if self.use_direct_call:
kv_cache_dummy_dep = None # Skip this if sharing KV cache with an earlier attention layer.
if not self.attn_backend.forward_includes_kv_cache_update: if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update( kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name key, value, self.layer_name
) )
...@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
else: else:
kv_cache_dummy_dep = None # Skip this if sharing KV cache with an earlier attention layer.
if not self.attn_backend.forward_includes_kv_cache_update and ( if (
# torch can only dispatch custom op if a tensor is passed not self.attn_backend.forward_includes_kv_cache_update
key is not None or value is not None and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
): ):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name key, value, self.layer_name
......
...@@ -136,6 +136,9 @@ def create_cross_attention_backend( ...@@ -136,6 +136,9 @@ def create_cross_attention_backend(
if ( if (
not underlying_attn_backend.forward_includes_kv_cache_update not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
): ):
self.do_kv_cache_update( self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping layer, key, value, kv_cache, attn_metadata.slot_mapping
......
...@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
if ( if (
not underlying_attn_backend.forward_includes_kv_cache_update not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
): ):
self.do_kv_cache_update( self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping layer, key, value, kv_cache, attn_metadata.slot_mapping
......
...@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
# we use direct Q, K, V tensors without caching # we use direct Q, K, V tensors without caching
return return
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is not None
or key is None
or value is None
):
return
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
......
...@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
): ):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are # Reshape the input keys and values and store them in the cache.
# calculated once based on the output from the encoder and then cached ops.reshape_and_cache_flash(
# in KV cache. key,
if ( value,
self.kv_sharing_target_layer_name is None key_cache,
and key is not None value_cache,
and value is not None slot_mapping,
): self.kv_cache_dtype,
# Reshape the input keys and values and store them in the cache. layer._k_scale,
# Skip this if sharing KV cache with an earlier attention layer. layer._v_scale,
ops.reshape_and_cache_flash( )
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
...@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size kv_cache, self.num_kv_heads, self.head_size
) )
# key and value may be None in the case of cross attention. They are # Reshape the input keys and values and store them in the cache.
# calculated once based on the output from the encoder and then cached # Get the actual block_size from value_cache
# in KV cache. # value_cache shape: [num_blocks, num_heads, head_size, block_size]
if ( block_size = value_cache.shape[3]
self.kv_sharing_target_layer_name is None # Determine if it is a power of 2
and key is not None is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
and value is not None
): if is_pow2:
# Reshape the input keys and values and store them in the cache. # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
# Skip this if sharing KV cache with an earlier attention layer. PagedAttention.write_to_paged_cache(
key,
# Get the actual block_size from value_cache value,
# value_cache shape: [num_blocks, num_heads, head_size, block_size] key_cache,
block_size = value_cache.shape[3] value_cache,
# Determine if it is a power of 2 slot_mapping,
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0) self.kv_cache_dtype,
layer._k_scale,
if is_pow2: layer._v_scale,
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic )
PagedAttention.write_to_paged_cache( else:
key, # Case B: Non-standard blocks (e.g., 544 in Qwen3),
value, # force using our modified Triton logic
key_cache, triton_reshape_and_cache_flash(
value_cache, key,
slot_mapping, value,
self.kv_cache_dtype, key_cache,
layer._k_scale, value_cache,
layer._v_scale, slot_mapping,
) self.kv_cache_dtype,
else: layer._k_scale,
# Case B: Non-standard blocks (e.g., 544 in Qwen3), layer._v_scale,
# force using our modified Triton logic )
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
...@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
if ( # Reshape the input keys and values and store them in the cache.
self.kv_sharing_target_layer_name is None if self.kv_cache_dtype.startswith("fp8"):
and key is not None key_cache = key_cache.view(self.fp8_dtype)
and value is not None value_cache = value_cache.view(self.fp8_dtype)
): # triton kernel does not support uint8 kv_cache
# Reshape the input keys and values and store them in the cache. # (because some explicit casts (e.g. float8_e4m3fnuz)
# Skip this if sharing KV cache with an earlier attention layer. # are not supported)
if self.kv_cache_dtype.startswith("fp8"): triton_reshape_and_cache_flash(
key_cache = key_cache.view(self.fp8_dtype) key,
value_cache = value_cache.view(self.fp8_dtype) value,
# triton kernel does not support uint8 kv_cache key_cache,
# (because some explicit casts (e.g. float8_e4m3fnuz) value_cache,
# are not supported) slot_mapping,
triton_reshape_and_cache_flash( self.kv_cache_dtype,
key, layer._k_scale,
value, layer._v_scale,
key_cache, )
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment