Unverified commit 5ae685c1, authored by Itay Etelis and committed by GitHub
Browse files

[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout (#34158)


Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
parent ce8cf916
...@@ -586,6 +586,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -586,6 +586,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata # if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if ( if (
can_use_trtllm can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization and not vllm_config.attention_config.disable_flashinfer_q_quantization
...@@ -1436,7 +1437,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -1436,7 +1437,6 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(prefill_query) assert is_strictly_contiguous(prefill_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_prefill) assert is_strictly_contiguous(block_tables_prefill)
assert is_strictly_contiguous(seq_lens_prefill) assert is_strictly_contiguous(seq_lens_prefill)
...@@ -1461,6 +1461,20 @@ class FlashInferImpl(AttentionImpl): ...@@ -1461,6 +1461,20 @@ class FlashInferImpl(AttentionImpl):
# and fp8 kv cache. So to enable prefill attention # and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block # with fp8 kv cache, we can construct a mock block
# and mock kv cache with BF16 KV involved in the prefill # and mock kv cache with BF16 KV involved in the prefill
#
# The inner (block_size, head_size) dims must be
# contiguous; outer dims may have non-canonical strides
# (e.g. cross-layer unified allocation).
# Degenerate strides on outer dims break TMA descriptors
# (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1
and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute, kv_cache_permute,
block_tables_prefill, block_tables_prefill,
...@@ -1549,10 +1563,21 @@ class FlashInferImpl(AttentionImpl): ...@@ -1549,10 +1563,21 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(decode_query) assert is_strictly_contiguous(decode_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode) assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode) assert is_strictly_contiguous(seq_lens_decode)
# kv_cache outer dims may be non-contiguous (e.g.
# cross-layer unified allocation), but inner dims
# (block_size, head_size) must be contiguous and
# strides must be canonical to avoid TMA descriptor
# failures (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)
if output.dtype == FP4_DTYPE: if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None assert self.o_sf_scale is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment