Unverified commit c3123207 authored by Huamin Li, committed by GitHub

[CI/Build] tests(v1): feed Triton attention the (num_blocks, 2, …) KV cache layout in backend-correctness tests (#26663)
Signed-off-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
parent c981f0ea
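
The layout choice named in the commit title can be sketched at the shape level. This is a minimal illustration, not the test suite's code: `Backend` is a hypothetical stand-in for `_Backend`, and the `FLASH_ATTN` and `FLEX_ATTENTION` member names are assumptions (only `FLASHINFER` and `TRITON_ATTN` appear in the diff).

```python
from enum import Enum, auto


class Backend(Enum):
    """Hypothetical stand-in for the test suite's _Backend enum."""
    FLASH_ATTN = auto()      # assumed member name
    FLEX_ATTENTION = auto()  # assumed member name
    FLASHINFER = auto()
    TRITON_ATTN = auto()


# Backends that expect the K/V axis at position 1 instead of position 0.
BLOCKS_FIRST = (Backend.FLASHINFER, Backend.TRITON_ATTN)


def kv_cache_shape_for(backend, shape):
    """Return the KV cache shape a backend expects.

    `shape` is the shared test layout
    (2, num_blocks, block_size, num_kv_heads, head_size).
    """
    if backend in BLOCKS_FIRST:
        # Same effect on the shape as tensor.transpose(0, 1).
        return (shape[1], shape[0], *shape[2:])
    return tuple(shape)


shared = (2, 128, 16, 4, 64)
for backend in Backend:
    print(backend.name, kv_cache_shape_for(backend, shared))
```

Before this commit only the FlashInfer case took the swapped branch, so the Triton backend was handed the `(2, num_blocks, ...)` layout.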
@@ -423,13 +423,14 @@ def _test_backend_correctness(
     for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         # [2, num_blocks, block_size, num_kv_heads, head_size]
-        # FlashInfer:
+        # FlashInfer + Triton:
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER:
+        if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
+        if backend_name == _Backend.FLASHINFER:
             # For FlashInfer default to HND layout and
             kv_cache_for_backend = (
                 kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
...
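
The FlashInfer-only `transpose(2, 3).contiguous().transpose(2, 3)` step keeps the logical shape but changes the memory order. The sketch below models that with plain shape and stride tuples, assuming PyTorch's usual row-major stride semantics; it does not import torch, and the helper names are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a densely packed tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def swap(seq, a, b):
    """Return a copy of seq with positions a and b exchanged."""
    out = list(seq)
    out[a], out[b] = out[b], out[a]
    return tuple(out)


def hnd_view(shape):
    """Model t.transpose(2, 3).contiguous().transpose(2, 3).

    Returns (shape, strides) of the result for a contiguous input
    of the given shape.
    """
    # transpose(2, 3): swap the block_size and num_kv_heads axes.
    hnd_shape = swap(shape, 2, 3)
    # contiguous(): repack memory densely in that swapped order.
    hnd_strides = contiguous_strides(hnd_shape)
    # transpose(2, 3) again: the logical shape is restored,
    # but the strides still describe the heads-major memory order.
    return swap(hnd_shape, 2, 3), swap(hnd_strides, 2, 3)


shape = (128, 2, 16, 4, 64)  # (num_blocks, 2, block_size, num_kv_heads, head_size)
print(contiguous_strides(shape))
print(hnd_view(shape))
```

In this model the result has the same shape as the input, while the stride of the `num_kv_heads` axis becomes larger than that of the `block_size` axis, which is what a heads-major (HND) physical layout looks like.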