"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "3f60f2244e3ffec6198d7a41765918d1efd3bb96"
Unverified Commit c4949772 authored by Samu Tamminen's avatar Samu Tamminen Committed by GitHub
Browse files

[ROCm][perf] Shuffle KV cache to use paged_attention_common (#32914)


Signed-off-by: default avatarSamu Tamminen <stammine@amd.com>
Co-authored-by: default avatarTuukka Sarvi <tuukka.sarvi@amd.com>
parent cb0b4432
...@@ -2070,5 +2070,56 @@ class rocm_aiter_ops: ...@@ -2070,5 +2070,56 @@ class rocm_aiter_ops:
out_=out_, out_=out_,
) )
@staticmethod
def paged_attention_common(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
tmp_out: torch.Tensor,
max_logits: torch.Tensor,
exp_sums: torch.Tensor,
max_seq_len: int,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_tables_stride0: int,
scale: float,
K_QScale_hip: torch.Tensor,
V_QScale_hip: torch.Tensor,
K_QScale_asm: torch.Tensor,
V_QScale_asm: torch.Tensor,
out_: torch.Tensor,
kv_cache_dtype: str,
):
"""
Paged attention common function.
This function is NOT wrapped with @is_aiter_supported decorator
to allow explicit backend selection via attention_config to work
even when VLLM_ROCM_USE_AITER=0.
Note: This performs lazy import of aiter.paged_attention_common
"""
from aiter import paged_attention_common
return paged_attention_common(
Q=Q,
K=K,
V=V,
tmp_out=tmp_out,
max_logits=max_logits,
exp_sums=exp_sums,
max_seq_len=max_seq_len,
block_tables=block_tables,
context_lens=context_lens,
block_tables_stride0=block_tables_stride0,
scale=scale,
K_QScale_hip=K_QScale_hip,
V_QScale_hip=V_QScale_hip,
K_QScale_asm=K_QScale_asm,
V_QScale_asm=V_QScale_asm,
out_=out_,
kv_cache_dtype=kv_cache_dtype,
)
rocm_aiter_ops.register_ops_once() rocm_aiter_ops.register_ops_once()
...@@ -1247,7 +1247,23 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1247,7 +1247,23 @@ class AiterFlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
) )
elif rocm_aiter_ops.is_shuffle_kv_cache_enabled(): elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape _, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_out = torch.empty(
(num_seqs, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=query.device,
)
exp_sums = torch.empty(
(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=query.device,
)
max_logits = torch.empty_like(exp_sums)
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
x = 16 // key_cache.element_size() x = 16 // key_cache.element_size()
k_cache_template = torch.empty( k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x], [num_blocks, num_kv_heads, head_size // x, block_size, x],
...@@ -1261,18 +1277,36 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1261,18 +1277,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
) )
new_key_cache = key_cache.view_as(k_cache_template) new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template) new_value_cache = value_cache.view_as(v_cache_template)
rocm_aiter_ops.pa_fwd_asm( k_qscale = (
layer._k_scale
if attn_metadata.k_scale is None
else attn_metadata.k_scale
)
v_qscale = (
layer._v_scale
if attn_metadata.v_scale is None
else attn_metadata.v_scale
)
rocm_aiter_ops.paged_attention_common(
Q=query[:num_decode_tokens], Q=query[:num_decode_tokens],
K=new_key_cache, K=new_key_cache,
V=new_value_cache, V=new_value_cache,
tmp_out=tmp_out,
max_logits=max_logits,
exp_sums=exp_sums,
max_seq_len=attn_metadata.max_seq_len,
block_tables=attn_metadata.block_table[:num_decodes], block_tables=attn_metadata.block_table[:num_decodes],
context_lens=attn_metadata.seq_lens[:num_decodes], context_lens=attn_metadata.seq_lens[:num_decodes],
block_tables_stride0=attn_metadata.block_table[ block_tables_stride0=attn_metadata.block_table[
:num_decodes :num_decodes
].stride(0), ].stride(0),
K_QScale=attn_metadata.k_scale, scale=self.scale,
V_QScale=attn_metadata.v_scale, K_QScale_hip=k_qscale,
V_QScale_hip=v_qscale,
K_QScale_asm=k_qscale,
V_QScale_asm=v_qscale,
out_=output[:num_decode_tokens], out_=output[:num_decode_tokens],
kv_cache_dtype=self.kv_cache_dtype,
) )
else: else:
_, num_heads, head_size = query.shape _, num_heads, head_size = query.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment