Unverified commit 6ef1efd5, authored by aditi-amd and committed by GitHub
Browse files

[ROCm] Fix TurboQuant on ROCm: backend routing, flash-attn compat, int64 overflow (#39953)


Signed-off-by: aditi <aditi.rana@amd.com>
parent 58da4ee0
...@@ -382,6 +382,7 @@ def _get_backend_priorities( ...@@ -382,6 +382,7 @@ def _get_backend_priorities(
if is_aiter_found_and_supported(): if is_aiter_found_and_supported():
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN) backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
backends.append(AttentionBackendEnum.TRITON_ATTN) backends.append(AttentionBackendEnum.TRITON_ATTN)
backends.append(AttentionBackendEnum.TURBOQUANT)
return backends return backends
......
...@@ -507,8 +507,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -507,8 +507,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# max_query_len == max_seq_len means no request has prior cached KV. # max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync. # Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len: if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype) return flash_attn_varlen_func(
flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -518,9 +517,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -518,9 +517,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k=attn_metadata.max_query_len, max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
out=output,
) )
return output
# Continuation or no flash_attn: per-request attention. # Continuation or no flash_attn: per-request attention.
# For continuation chunks (seq_len > q_len), we must attend to # For continuation chunks (seq_len > q_len), we must attend to
...@@ -557,10 +554,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -557,10 +554,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if q_len == seq_len: if q_len == seq_len:
# First-chunk prefill: all K/V are in the current batch. # First-chunk prefill: all K/V are in the current batch.
if _HAS_FLASH_ATTN: if _HAS_FLASH_ATTN:
out = torch.empty_like(q_seq)
_cu_2[1] = q_len _cu_2[1] = q_len
cu = _cu_2 cu = _cu_2
flash_attn_varlen_func( out = flash_attn_varlen_func(
q=q_seq, q=q_seq,
k=k_seq, k=k_seq,
v=v_seq, v=v_seq,
...@@ -570,7 +566,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -570,7 +566,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k=q_len, max_seqlen_k=q_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
out=out,
) )
else: else:
q_t = q_seq.transpose(0, 1).contiguous() q_t = q_seq.transpose(0, 1).contiguous()
...@@ -733,10 +728,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -733,10 +728,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# Attention: q_len queries attending to seq_len K/V with causal mask # Attention: q_len queries attending to seq_len K/V with causal mask
if _HAS_FLASH_ATTN: if _HAS_FLASH_ATTN:
output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32) cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32) cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
flash_attn_varlen_func( return flash_attn_varlen_func(
q=query, q=query,
k=k_full, k=k_full,
v=v_full, v=v_full,
...@@ -746,9 +740,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): ...@@ -746,9 +740,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k=seq_len, max_seqlen_k=seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
out=output,
) )
return output
else: else:
# SDPA fallback: expand KV for GQA, build causal mask # SDPA fallback: expand KV for GQA, build causal mask
q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D) q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D)
......
...@@ -143,12 +143,12 @@ def _tq_decode_stage1( ...@@ -143,12 +143,12 @@ def _tq_decode_stage1(
Block_table_ptr + bt_base + page_idx, Block_table_ptr + bt_base + page_idx,
mask=kv_mask, mask=kv_mask,
other=0, other=0,
) ).to(tl.int64)
slot_bases = ( slot_bases = (
block_nums * stride_cache_block block_nums * stride_cache_block
+ page_off * stride_cache_pos + page_off.to(tl.int64) * stride_cache_pos
+ kv_head * stride_cache_head + tl.cast(kv_head, tl.int64) * stride_cache_head
) )
# ============================================================ # ============================================================
...@@ -356,11 +356,11 @@ def _tq_full_dequant_kv( ...@@ -356,11 +356,11 @@ def _tq_full_dequant_kv(
page_idx = pos // BLOCK_SIZE page_idx = pos // BLOCK_SIZE
page_off = pos % BLOCK_SIZE page_off = pos % BLOCK_SIZE
block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx) block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx).to(tl.int64)
slot_base = ( slot_base = (
block_num * stride_cache_block block_num * stride_cache_block
+ page_off * stride_cache_pos + tl.cast(page_off, tl.int64) * stride_cache_pos
+ hid * stride_cache_head + tl.cast(hid, tl.int64) * stride_cache_head
) )
d_offs = tl.arange(0, BLOCK_D) d_offs = tl.arange(0, BLOCK_D)
......
...@@ -174,10 +174,13 @@ def _tq_fused_store_fp8( ...@@ -174,10 +174,13 @@ def _tq_fused_store_fp8(
slot = tl.load(Slot_mapping_ptr + token_idx) slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0: if slot < 0:
return return
blk = slot // BLOCK_SIZE blk = (slot // BLOCK_SIZE).to(tl.int64)
off = slot % BLOCK_SIZE off = (slot % BLOCK_SIZE).to(tl.int64)
head_idx_i64 = tl.cast(head_idx, tl.int64)
slot_base = ( slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head blk * stride_cache_block
+ off * stride_cache_pos
+ head_idx_i64 * stride_cache_head
) )
base = pid * D base = pid * D
...@@ -259,10 +262,13 @@ def _tq_fused_store_mse( ...@@ -259,10 +262,13 @@ def _tq_fused_store_mse(
slot = tl.load(Slot_mapping_ptr + token_idx) slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0: if slot < 0:
return return
blk = slot // BLOCK_SIZE blk = (slot // BLOCK_SIZE).to(tl.int64)
off = slot % BLOCK_SIZE off = (slot % BLOCK_SIZE).to(tl.int64)
head_idx_i64 = tl.cast(head_idx, tl.int64)
slot_base = ( slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head blk * stride_cache_block
+ off * stride_cache_pos
+ head_idx_i64 * stride_cache_head
) )
base = pid * D base = pid * D
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment