Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d097697
Unverified
Commit
6d097697
authored
Apr 21, 2026
by
Micah Williamson
Committed by
GitHub
Apr 22, 2026
Browse files
[ROCm] Support non-causal attention in ROCM_ATTN (#40176)
Signed-off-by:
Micah Williamson
<
micah.williamson@amd.com
>
parent
4506319a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
13 deletions
+41
-13
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+6
-2
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+11
-3
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
+2
-0
vllm/v1/attention/ops/prefix_prefill.py
vllm/v1/attention/ops/prefix_prefill.py
+22
-8
No files found.
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
View file @
6d097697
...
...
@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
AttentionLayer
,
AttentionType
,
MultipleOf
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.rocm_attn
import
(
RocmAttentionBackend
,
RocmAttentionImpl
,
RocmAttentionMetadata
,
RocmAttentionMetadataBuilder
,
)
...
...
@@ -53,6 +53,10 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def
supports_sink
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_non_causal
(
cls
)
->
bool
:
return
False
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
...
...
@@ -140,7 +144,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
Flash
AttentionMetadata
,
attn_metadata
:
Rocm
AttentionMetadata
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
6d097697
...
...
@@ -27,7 +27,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata
,
MultipleOf
,
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
,
)
...
...
@@ -69,6 +68,9 @@ class RocmAttentionMetadata:
scheduler_metadata
:
torch
.
Tensor
|
None
=
None
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
# DFlash drafting sets this to False via CommonAttentionMetadata.
causal
:
bool
=
True
class
RocmAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
RocmAttentionMetadata
]):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
ALWAYS
...
...
@@ -154,6 +156,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
causal
=
common_attn_metadata
.
causal
,
)
return
attn_metadata
...
...
@@ -200,6 +203,10 @@ class RocmAttentionBackend(AttentionBackend):
# kernel, which is less efficient than the proper triton backends.
return
False
@
classmethod
def
supports_non_causal
(
cls
)
->
bool
:
return
True
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
...
...
@@ -301,7 +308,7 @@ class RocmAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
attn_metadata
:
Flash
AttentionMetadata
,
attn_metadata
:
Rocm
AttentionMetadata
,
layer
:
torch
.
nn
.
Module
,
)
->
torch
.
Tensor
:
"""Forward pass for encoder attention without KV cache.
...
...
@@ -350,7 +357,7 @@ class RocmAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
Flash
AttentionMetadata
,
attn_metadata
:
Rocm
AttentionMetadata
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -438,6 +445,7 @@ class RocmAttentionImpl(AttentionImpl):
sm_scale
=
self
.
scale
,
output_scale
=
output_scale
,
sinks
=
self
.
sinks
,
causal
=
attn_metadata
.
causal
,
)
return
output
...
...
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
View file @
6d097697
...
...
@@ -269,6 +269,7 @@ def chunked_prefill_paged_decode(
# Optional tensor for sinks
sinks
=
None
,
is_block_table_ptr
:
bool
=
False
,
causal
:
bool
=
True
,
):
if
sm_scale
is
None
:
sm_scale
=
1.0
/
(
query
.
shape
[
2
]
**
0.5
)
...
...
@@ -300,6 +301,7 @@ def chunked_prefill_paged_decode(
skip_decode
=
True
,
fp8_out_scale
=
output_scale
,
sinks
=
sinks
,
causal
=
causal
,
)
block_size
=
value_cache
.
shape
[
3
]
...
...
vllm/v1/attention/ops/prefix_prefill.py
View file @
6d097697
...
...
@@ -89,6 +89,7 @@ def _fwd_kernel(
SKIP_DECODE
:
tl
.
constexpr
,
USE_SINKS
:
tl
.
constexpr
,
USE_FP8
:
tl
.
constexpr
,
CAUSAL
:
tl
.
constexpr
=
True
,
MAX_Q_LEN
:
tl
.
constexpr
=
0
,
MAX_CTX_LEN
:
tl
.
constexpr
=
0
,
FP8_MIN
:
tl
.
constexpr
=
float8_info
.
min
,
...
...
@@ -283,10 +284,17 @@ def _fwd_kernel(
# block_mask is 0 when we're already past the current query length
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
# compute query against itself (with causal mask)
# compute query against itself (causal among queries by default;
# CAUSAL=False for bidirectional attention over query tokens, e.g. DFlash.)
if
CAUSAL
:
key_range_upper
=
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
else
:
q_len_pad
=
(
cur_batch_query_len
+
BLOCK_N
-
1
)
//
BLOCK_N
*
BLOCK_N
key_range_upper
=
block_mask
*
q_len_pad
for
start_n
in
tl
.
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
key_range_upper
,
BLOCK_N
,
loop_unroll_factor
=
num_unroll_request
,
):
...
...
@@ -302,14 +310,17 @@ def _fwd_kernel(
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
dot
(
q
,
k
,
acc
=
qk
,
input_precision
=
IN_PRECISION
)
qk
*=
sm_scale
# apply causal mask
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
valid_kv
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_query_len
if
CAUSAL
:
attn_mask
=
valid_kv
&
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]))
else
:
attn_mask
=
valid_kv
if
SLIDING_WINDOW
>
0
:
qk
=
tl
.
where
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
float
(
"-inf"
),
attn_mask
=
attn_mask
&
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
)
qk
=
tl
.
where
(
attn_mask
,
qk
,
float
(
"-inf"
))
# compute running maximum
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
axis
=
1
))
...
...
@@ -656,6 +667,7 @@ def context_attention_fwd(
fp8_out_scale
=
None
,
sinks
=
None
,
is_block_table_ptr
:
bool
=
False
,
causal
:
bool
=
True
,
):
q_dtype_is_f32
=
q
.
dtype
is
torch
.
float32
...
...
@@ -722,6 +734,7 @@ def context_attention_fwd(
processed_b_loc
=
b_loc
.
to
(
torch
.
int32
)
if
alibi_slopes
is
not
None
:
assert
causal
,
"Non-causal prefix attention is not supported with alibi"
assert
sinks
is
None
,
"Sinks arg is not supported with alibi"
assert
fp8_out_scale
is
None
,
"FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
...
...
@@ -859,6 +872,7 @@ def context_attention_fwd(
num_warps
=
4
,
num_stages
=
1
,
USE_SINKS
=
sinks
is
not
None
,
CAUSAL
=
causal
,
**
extra_kargs
,
)
return
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment