Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6ef1efd5
Unverified
Commit
6ef1efd5
Authored Apr 17, 2026 by aditi-amd; committed by GitHub on Apr 17, 2026
Browse files
[ROCm] Fix TurboQuant on ROCm: backend routing, flash-attn compat, int64 overflow (#39953)
Signed-off-by: aditi <aditi.rana@amd.com>
parent
58da4ee0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing 4 changed files with 22 additions and 23 deletions
+22
-23
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-0
vllm/v1/attention/backends/turboquant_attn.py
vllm/v1/attention/backends/turboquant_attn.py
+3
-11
vllm/v1/attention/ops/triton_turboquant_decode.py
vllm/v1/attention/ops/triton_turboquant_decode.py
+6
-6
vllm/v1/attention/ops/triton_turboquant_store.py
vllm/v1/attention/ops/triton_turboquant_store.py
+12
-6
No files found.
vllm/platforms/rocm.py
View file @
6ef1efd5
...
@@ -382,6 +382,7 @@ def _get_backend_priorities(
...
@@ -382,6 +382,7 @@ def _get_backend_priorities(
if
is_aiter_found_and_supported
():
if
is_aiter_found_and_supported
():
backends
.
append
(
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
)
backends
.
append
(
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
)
backends
.
append
(
AttentionBackendEnum
.
TRITON_ATTN
)
backends
.
append
(
AttentionBackendEnum
.
TRITON_ATTN
)
backends
.
append
(
AttentionBackendEnum
.
TURBOQUANT
)
return
backends
return
backends
...
...
vllm/v1/attention/backends/turboquant_attn.py
View file @
6ef1efd5
...
@@ -507,8 +507,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -507,8 +507,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# max_query_len == max_seq_len means no request has prior cached KV.
# max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync.
# Both are Python ints — no GPU sync.
if
_HAS_FLASH_ATTN
and
attn_metadata
.
max_query_len
==
attn_metadata
.
max_seq_len
:
if
_HAS_FLASH_ATTN
and
attn_metadata
.
max_query_len
==
attn_metadata
.
max_seq_len
:
output
=
torch
.
empty
(
N
,
Hq
,
D
,
device
=
query
.
device
,
dtype
=
query
.
dtype
)
return
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -518,9 +517,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -518,9 +517,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k
=
attn_metadata
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
max_query_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
out
=
output
,
)
)
return
output
# Continuation or no flash_attn: per-request attention.
# Continuation or no flash_attn: per-request attention.
# For continuation chunks (seq_len > q_len), we must attend to
# For continuation chunks (seq_len > q_len), we must attend to
...
@@ -557,10 +554,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -557,10 +554,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
if
q_len
==
seq_len
:
if
q_len
==
seq_len
:
# First-chunk prefill: all K/V are in the current batch.
# First-chunk prefill: all K/V are in the current batch.
if
_HAS_FLASH_ATTN
:
if
_HAS_FLASH_ATTN
:
out
=
torch
.
empty_like
(
q_seq
)
_cu_2
[
1
]
=
q_len
_cu_2
[
1
]
=
q_len
cu
=
_cu_2
cu
=
_cu_2
flash_attn_varlen_func
(
out
=
flash_attn_varlen_func
(
q
=
q_seq
,
q
=
q_seq
,
k
=
k_seq
,
k
=
k_seq
,
v
=
v_seq
,
v
=
v_seq
,
...
@@ -570,7 +566,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -570,7 +566,6 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k
=
q_len
,
max_seqlen_k
=
q_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
out
=
out
,
)
)
else
:
else
:
q_t
=
q_seq
.
transpose
(
0
,
1
).
contiguous
()
q_t
=
q_seq
.
transpose
(
0
,
1
).
contiguous
()
...
@@ -733,10 +728,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -733,10 +728,9 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
# Attention: q_len queries attending to seq_len K/V with causal mask
# Attention: q_len queries attending to seq_len K/V with causal mask
if
_HAS_FLASH_ATTN
:
if
_HAS_FLASH_ATTN
:
output
=
torch
.
empty
(
q_len
,
Hq
,
D
,
device
=
device
,
dtype
=
query
.
dtype
)
cu_seqlens_q
=
torch
.
tensor
([
0
,
q_len
],
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_q
=
torch
.
tensor
([
0
,
q_len
],
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
tensor
([
0
,
seq_len
],
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
tensor
([
0
,
seq_len
],
device
=
device
,
dtype
=
torch
.
int32
)
flash_attn_varlen_func
(
return
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
k_full
,
k
=
k_full
,
v
=
v_full
,
v
=
v_full
,
...
@@ -746,9 +740,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
...
@@ -746,9 +740,7 @@ class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
max_seqlen_k
=
seq_len
,
max_seqlen_k
=
seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
out
=
output
,
)
)
return
output
else
:
else
:
# SDPA fallback: expand KV for GQA, build causal mask
# SDPA fallback: expand KV for GQA, build causal mask
q_t
=
query
.
transpose
(
0
,
1
).
unsqueeze
(
0
)
# (1, Hq, q_len, D)
q_t
=
query
.
transpose
(
0
,
1
).
unsqueeze
(
0
)
# (1, Hq, q_len, D)
...
...
vllm/v1/attention/ops/triton_turboquant_decode.py
View file @
6ef1efd5
...
@@ -143,12 +143,12 @@ def _tq_decode_stage1(
...
@@ -143,12 +143,12 @@ def _tq_decode_stage1(
Block_table_ptr
+
bt_base
+
page_idx
,
Block_table_ptr
+
bt_base
+
page_idx
,
mask
=
kv_mask
,
mask
=
kv_mask
,
other
=
0
,
other
=
0
,
)
)
.
to
(
tl
.
int64
)
slot_bases
=
(
slot_bases
=
(
block_nums
*
stride_cache_block
block_nums
*
stride_cache_block
+
page_off
*
stride_cache_pos
+
page_off
.
to
(
tl
.
int64
)
*
stride_cache_pos
+
kv_head
*
stride_cache_head
+
tl
.
cast
(
kv_head
,
tl
.
int64
)
*
stride_cache_head
)
)
# ============================================================
# ============================================================
...
@@ -356,11 +356,11 @@ def _tq_full_dequant_kv(
...
@@ -356,11 +356,11 @@ def _tq_full_dequant_kv(
page_idx
=
pos
//
BLOCK_SIZE
page_idx
=
pos
//
BLOCK_SIZE
page_off
=
pos
%
BLOCK_SIZE
page_off
=
pos
%
BLOCK_SIZE
block_num
=
tl
.
load
(
Block_table_ptr
+
bid
*
stride_bt_b
+
page_idx
)
block_num
=
tl
.
load
(
Block_table_ptr
+
bid
*
stride_bt_b
+
page_idx
)
.
to
(
tl
.
int64
)
slot_base
=
(
slot_base
=
(
block_num
*
stride_cache_block
block_num
*
stride_cache_block
+
page_off
*
stride_cache_pos
+
tl
.
cast
(
page_off
,
tl
.
int64
)
*
stride_cache_pos
+
hid
*
stride_cache_head
+
tl
.
cast
(
hid
,
tl
.
int64
)
*
stride_cache_head
)
)
d_offs
=
tl
.
arange
(
0
,
BLOCK_D
)
d_offs
=
tl
.
arange
(
0
,
BLOCK_D
)
...
...
vllm/v1/attention/ops/triton_turboquant_store.py
View file @
6ef1efd5
...
@@ -174,10 +174,13 @@ def _tq_fused_store_fp8(
...
@@ -174,10 +174,13 @@ def _tq_fused_store_fp8(
slot
=
tl
.
load
(
Slot_mapping_ptr
+
token_idx
)
slot
=
tl
.
load
(
Slot_mapping_ptr
+
token_idx
)
if
slot
<
0
:
if
slot
<
0
:
return
return
blk
=
slot
//
BLOCK_SIZE
blk
=
(
slot
//
BLOCK_SIZE
).
to
(
tl
.
int64
)
off
=
slot
%
BLOCK_SIZE
off
=
(
slot
%
BLOCK_SIZE
).
to
(
tl
.
int64
)
head_idx_i64
=
tl
.
cast
(
head_idx
,
tl
.
int64
)
slot_base
=
(
slot_base
=
(
blk
*
stride_cache_block
+
off
*
stride_cache_pos
+
head_idx
*
stride_cache_head
blk
*
stride_cache_block
+
off
*
stride_cache_pos
+
head_idx_i64
*
stride_cache_head
)
)
base
=
pid
*
D
base
=
pid
*
D
...
@@ -259,10 +262,13 @@ def _tq_fused_store_mse(
...
@@ -259,10 +262,13 @@ def _tq_fused_store_mse(
slot
=
tl
.
load
(
Slot_mapping_ptr
+
token_idx
)
slot
=
tl
.
load
(
Slot_mapping_ptr
+
token_idx
)
if
slot
<
0
:
if
slot
<
0
:
return
return
blk
=
slot
//
BLOCK_SIZE
blk
=
(
slot
//
BLOCK_SIZE
).
to
(
tl
.
int64
)
off
=
slot
%
BLOCK_SIZE
off
=
(
slot
%
BLOCK_SIZE
).
to
(
tl
.
int64
)
head_idx_i64
=
tl
.
cast
(
head_idx
,
tl
.
int64
)
slot_base
=
(
slot_base
=
(
blk
*
stride_cache_block
+
off
*
stride_cache_pos
+
head_idx
*
stride_cache_head
blk
*
stride_cache_block
+
off
*
stride_cache_pos
+
head_idx_i64
*
stride_cache_head
)
)
base
=
pid
*
D
base
=
pid
*
D
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.