Unverified commit 6101a26d, authored by Vadim Gimpelson and committed by GitHub
Browse files

[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer...


[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer backend. Fixes issue #32353 (#32417)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
parent f5d17400
......@@ -1385,8 +1385,11 @@ class FlashInferImpl(AttentionImpl):
)
else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
# prefill_query may be non-contiguous or have degenerate strides
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.prefill.seq_lens
......@@ -1495,9 +1498,12 @@ class FlashInferImpl(AttentionImpl):
out=output[:num_decode_tokens],
)
else:
# decode_query may be non-contiguous
# decode_query may be non-contiguous or have degenerate strides
assert isinstance(attn_metadata.decode, TRTLLMDecode)
decode_query = decode_query.contiguous()
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
decode_query = decode_query.contiguous().reshape(decode_query.shape)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.decode.block_tables
seq_lens_decode = attn_metadata.decode.seq_lens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment