Unverified commit 22ad6495, authored by Gregory Shtrasberg, committed by GitHub
Browse files

[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends (#33106)


Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 36d450e3
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backend import AttentionLayer, AttentionType
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
......@@ -24,6 +24,8 @@ logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
forward_includes_kv_cache_update: bool = False
# Static accessor for this backend's name string; takes no arguments and
# returns the fixed identifier "ROCM_AITER_UNIFIED_ATTN".
@staticmethod
def get_name() -> str:
return "ROCM_AITER_UNIFIED_ATTN"
......@@ -142,27 +144,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......@@ -204,3 +185,34 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
)
return output
def do_kv_cache_update(
    self,
    layer: AttentionLayer,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store the given keys and values into the KV cache.

    Args:
        layer: Attention layer whose ``_k_scale`` / ``_v_scale`` are
            forwarded to the cache-write op.
        key: New keys, or ``None`` (e.g. cross attention, where keys are
            computed once from the encoder output and then cached).
        value: New values, or ``None`` for the same reason as ``key``.
        kv_cache: Combined cache tensor; it is split along dim 0 into a
            key part and a value part.
        slot_mapping: Passed through to the cache-write op to select the
            destination slots.

    Returns:
        None. When a write happens it is done by
        ``ops.reshape_and_cache_flash``.
    """
    # The split happens unconditionally, before any skip check.
    k_store, v_store = kv_cache.unbind(0)

    # Nothing to write when this layer shares the KV cache of an earlier
    # attention layer, or when there are no new keys/values to store.
    if (
        self.kv_sharing_target_layer_name is not None
        or key is None
        or value is None
    ):
        return

    # Reshape the incoming keys/values and store them in the cache.
    ops.reshape_and_cache_flash(
        key,
        value,
        k_store,
        v_store,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
......@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
......@@ -193,6 +194,8 @@ class RocmAttentionBackend(AttentionBackend):
"FlexAttention backend which supports all head sizes."
)
forward_includes_kv_cache_update: bool = False
# Static accessor for this backend's name string; takes no arguments and
# returns the fixed identifier "ROCM_ATTN".
@staticmethod
def get_name() -> str:
return "ROCM_ATTN"
......@@ -330,6 +333,56 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens] if key is not None else None,
value=value[:num_actual_tokens] if value is not None else None,
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output
def do_kv_cache_update(
self,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
......@@ -354,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
......@@ -367,46 +420,8 @@ class RocmAttentionImpl(AttentionImpl):
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens] if key is not None else None,
value=value[:num_actual_tokens] if value is not None else None,
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output
......@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
......@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
forward_includes_kv_cache_update: bool = False
# Static accessor for this backend's name string; takes no arguments and
# returns the fixed identifier "TRITON_ATTN".
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"
......@@ -461,31 +464,6 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
......@@ -585,3 +563,38 @@ class TritonAttentionImpl(AttentionImpl):
sliding_window_k=self.sliding_window[1],
)
return output
def do_kv_cache_update(
    self,
    layer: AttentionLayer,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store the given keys and values into the KV cache via the Triton op.

    Args:
        layer: Attention layer whose ``_k_scale`` / ``_v_scale`` are
            forwarded to the cache-write kernel.
        key: New keys, or ``None`` when there is nothing to store.
        value: New values, or ``None`` when there is nothing to store.
        kv_cache: Combined cache tensor; it is split along dim 1 into a
            key part and a value part.
        slot_mapping: Passed through to the cache-write kernel to select
            the destination slots.

    Returns:
        None. When a write happens it is done by
        ``triton_reshape_and_cache_flash``.
    """
    # For decoder and cross-attention the KV cache is used; the split is
    # done unconditionally, before any skip check.
    k_store, v_store = kv_cache.unbind(1)

    # Nothing to write when this layer shares the KV cache of an earlier
    # attention layer, or when there are no new keys/values to store.
    if (
        self.kv_sharing_target_layer_name is not None
        or key is None
        or value is None
    ):
        return

    if self.kv_cache_dtype.startswith("fp8"):
        # The triton kernel does not support a uint8 kv_cache (some
        # explicit casts, e.g. float8_e4m3fnuz, are not supported), so
        # both halves are reinterpreted as the fp8 dtype first.
        k_store, v_store = (
            k_store.view(self.fp8_dtype),
            v_store.view(self.fp8_dtype),
        )

    # Reshape the incoming keys/values and store them in the cache.
    triton_reshape_and_cache_flash(
        key,
        value,
        k_store,
        v_store,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment