Unverified Commit 79e4937c authored by iefgnoix's avatar iefgnoix Committed by GitHub
Browse files

[v1] Add comments to the new ragged paged attention Pallas kernel (#14155)


Signed-off-by: default avatarXiongfei Wei <isaacwxf23@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent cd1d3c3d
......@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128
......@@ -154,6 +155,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
# use_kernel switches between using kernel or reference implementation
# (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
use_kernel = False
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
......@@ -164,7 +168,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
use_kernel=False,
use_kernel=use_kernel,
)
return output.reshape(num_tokens, hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment