Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6aa5b18e
Unverified
Commit
6aa5b18e
authored
Jan 06, 2026
by
Isotr0py
Committed by
GitHub
Jan 06, 2026
Browse files
[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
911d38ed
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
627 additions
and
14 deletions
+627
-14
tests/kernels/attention/test_triton_prefill_attention.py
tests/kernels/attention/test_triton_prefill_attention.py
+225
-0
tests/models/multimodal/generation/test_whisper.py
tests/models/multimodal/generation/test_whisper.py
+1
-1
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+57
-0
vllm/attention/ops/triton_prefill_attention.py
vllm/attention/ops/triton_prefill_attention.py
+271
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+0
-9
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+73
-4
No files found.
tests/kernels/attention/test_triton_prefill_attention.py
0 → 100644
View file @
6aa5b18e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.attention.ops.triton_prefill_attention
import
context_attention_fwd
def
ref_masked_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
is_causal
:
bool
=
True
,
sliding_window_q
:
int
|
None
=
None
,
sliding_window_k
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Reference implementation using PyTorch SDPA."""
# q, k, v: [total_tokens, num_heads, head_dim]
# SDPA expects [batch, num_heads, seq_len, head_dim]
total_tokens
=
q
.
shape
[
0
]
# Add batch dimension and transpose
q
=
q
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
k
=
k
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
# Create attention mask if needed
attn_mask
=
None
use_causal
=
is_causal
# If we have sliding window or need custom masking, create explicit mask
sliding_window_q
=
sliding_window_q
if
sliding_window_q
is
not
None
else
0
sliding_window_k
=
sliding_window_k
if
sliding_window_k
is
not
None
else
0
if
(
sliding_window_q
>
0
)
or
(
sliding_window_k
>
0
):
# Position indices
pos_q
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
1
)
pos_k
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
0
)
# Start with valid mask (False = no masking)
mask
=
torch
.
ones
(
(
total_tokens
,
total_tokens
),
dtype
=
torch
.
bool
,
device
=
q
.
device
)
# Apply causal mask
if
is_causal
:
mask
=
mask
&
(
pos_q
>=
pos_k
)
# Apply sliding window masks
sliding_window_mask
=
torch
.
ones_like
(
mask
)
if
sliding_window_q
>
0
:
sliding_window_mask
&=
pos_q
-
pos_k
<=
sliding_window_q
if
sliding_window_k
>
0
:
sliding_window_mask
&=
pos_k
-
pos_q
<=
sliding_window_k
mask
=
mask
&
sliding_window_mask
attn_mask
=
torch
.
where
(
mask
,
0.0
,
float
(
"-inf"
)).
to
(
q
.
dtype
)
use_causal
=
False
# Don't use is_causal when providing explicit mask
# Use SDPA
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attn_mask
,
is_causal
=
use_causal
,
dropout_p
=
0.0
)
# Convert back to original shape: [total_tokens, num_heads, head_dim]
output
=
output
.
transpose
(
1
,
2
).
squeeze
(
0
)
return
output
@pytest.mark.parametrize("B", [5])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    is_causal: bool,
    dtype: torch.dtype,
):
    """Check the Triton prefill kernel against SDPA with no sliding window.

    A packed batch of ``B`` variable-length sequences is run through
    ``context_attention_fwd`` in one call; the reference is computed one
    sequence at a time with ``ref_masked_attention``.
    """
    device = "cuda"
    torch.manual_seed(42)

    # Each sequence gets a random length in [max_seq_len // 2, max_seq_len].
    seq_lens = torch.randint(
        max_seq_len // 2, max_seq_len + 1, (B,), device=device
    )
    total_tokens = seq_lens.sum().item()

    # Offset of each sequence's first token inside the packed token axis.
    b_start_loc = torch.zeros(B, dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device=device)
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    o = torch.zeros_like(q)

    # Kernel under test writes its result into ``o``.
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=is_causal,
        sliding_window_q=None,
        sliding_window_k=None,
    )

    # Reference result, assembled sequence by sequence.
    o_ref = torch.zeros_like(q)
    for seq_start, seq_len in zip(b_start_loc.tolist(), seq_lens.tolist()):
        span = slice(seq_start, seq_start + seq_len)
        q_seq, k_seq, v_seq = q[span], k[span], v[span]

        # The reference has no GQA support: replicate KV heads so each
        # query head has its own copy.
        if H_Q != H_KV:
            repeats = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(repeats, dim=1)
            v_seq = v_seq.repeat_interleave(repeats, dim=1)

        o_ref[span] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=is_causal,
            sliding_window_q=None,
            sliding_window_k=None,
        )

    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    sliding_window: tuple[int, int],
    dtype: torch.dtype,
):
    """Test context attention with sliding window.

    ``sliding_window`` is a ``(sliding_window_q, sliding_window_k)`` pair;
    a value of 0 disables that side of the window. The kernel is run
    non-causally and compared with a per-sequence SDPA reference.
    """
    # The docstring has to be the first statement of the function; the
    # tuple is unpacked only after it.
    sliding_window_q, sliding_window_k = sliding_window

    torch.manual_seed(42)

    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()

    # Create batch start locations
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Create input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Call Triton kernel
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=False,
        sliding_window_q=sliding_window_q,
        sliding_window_k=sliding_window_k,
    )

    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()

        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]

        # Expand KV heads if using GQA
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        # A disabled side (0) is passed to the reference as None.
        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=False,
            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
        )

    # Compare outputs
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
tests/models/multimodal/generation/test_whisper.py
View file @
6aa5b18e
...
...
@@ -114,7 +114,7 @@ def check_model_available(model: str) -> None:
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
cpu_model
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"openai/whisper-large-v3-turbo"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
,
"float"
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
create_new_process_for_each_test
(
"spawn"
)
...
...
tests/v1/attention/test_attention_backends.py
View file @
6aa5b18e
...
...
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config
,
try_get_attention_backend
,
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
ModelConfig
from
vllm.platforms
import
current_platform
...
...
@@ -83,6 +84,13 @@ BATCH_SPECS = {
),
"single_decode"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
1
]),
"single_prefill"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
64
]),
# encoder-only
"small_encoder_prefill"
:
BatchSpec
(
seq_lens
=
[
32
,
64
,
128
,
256
],
query_lens
=
[
32
,
64
,
128
,
256
]
),
"medium_encoder_prefill"
:
BatchSpec
(
seq_lens
=
[
256
,
512
,
1024
,
2048
],
query_lens
=
[
256
,
512
,
1024
,
2048
]
),
}
...
...
@@ -209,6 +217,7 @@ def run_attention_backend(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
sliding_window
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
...
...
@@ -276,6 +285,7 @@ def run_attention_backend(
num_kv_heads
=
num_kv_heads
,
alibi_slopes
=
None
,
sliding_window
=
sliding_window
,
attn_type
=
attn_type
,
kv_cache_dtype
=
"auto"
,
)
...
...
@@ -299,6 +309,7 @@ def _test_backend_correctness(
backend_to_test
:
list
[
AttentionBackendEnum
|
str
],
mask_mod
,
*
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
block_size
:
int
=
16
,
atol
:
float
=
1e-2
,
rtol
:
float
=
1e-2
,
...
...
@@ -436,6 +447,9 @@ def _test_backend_correctness(
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
)
if
attn_type
==
AttentionType
.
ENCODER_ONLY
:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata
.
causal
=
False
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache
=
create_and_prepopulate_kv_cache
(
...
...
@@ -491,6 +505,7 @@ def _test_backend_correctness(
value_vllm
,
kv_cache_for_backend
,
sliding_window
=
sliding_window
,
attn_type
=
attn_type
,
)
finally
:
if
reset_kv_cache_layout
:
...
...
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
block_size
=
128
,
tensor_parallel_size
=
tensor_parallel_size
,
)
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_encoder_prefill",
        "medium_encoder_prefill",
    ],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Check encoder-only backends against a bidirectional sliding window.

    The window size is read from the model's own config and baked into the
    mask function handed to ``_test_backend_correctness``.
    """
    spec = BATCH_SPECS[batch_spec_name]
    config = ModelConfig(model=model, max_model_len=max(spec.seq_lens))
    window = config.get_sliding_window()

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # A key is visible when its distance from the query (shifted by
        # context_len) is strictly inside the window, in either direction.
        distance = torch.abs(q_idx + context_len - kv_idx)
        return distance < sliding_window

    # context_len stays unbound here; only the window size is fixed.
    mask_mod = partial(bidi_sliding_window_mask_mod, sliding_window=window)

    _test_backend_correctness(
        spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        mask_mod,
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )
vllm/attention/ops/triton_prefill_attention.py
0 → 100644
View file @
6aa5b18e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/97cb762bb65ebf05025eb342de03c184660427a3/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# Changes:
# - Add support for sliding window attention
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import
torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SLIDING_WINDOW_Q: tl.constexpr,
    SLIDING_WINDOW_K: tl.constexpr,
    Lk: tl.constexpr,
):
    """Blocked attention over packed variable-length sequences.

    One program handles one (sequence, query head, block of BLOCK_M queries)
    triple, taken from grid axes 0, 1 and 2. It walks the keys/values of the
    same sequence in steps of BLOCK_N, keeping a running row maximum (m_i),
    a running softmax denominator (l_i) and an accumulator (acc) that is
    renormalised on every step, then stores acc into Out.

    Args:
        Q, K, V, Out: Base pointers of the query, key, value and output
            tensors.
        sm_scale: Factor applied to the Q.K scores before the softmax.
        B_Start_Loc: Per-sequence offset of the first token along the
            packed token axis.
        B_Seqlen: Per-sequence token count.
        stride_*bs / stride_*h: Strides applied to the token index and the
            head index of the corresponding tensor. The feature axis is
            addressed with an implicit stride of 1.
        kv_group_num: Number of query heads that share one KV head.
        BLOCK_M, BLOCK_N: Query / key tile sizes.
        BLOCK_DMODEL: Width of the feature tile; lanes >= Lk are masked.
        IS_CAUSAL: Restrict each query to keys at positions <= its own.
        SLIDING_WINDOW_Q: If > 0, keep only keys with
            ``pos_q - pos_k <= SLIDING_WINDOW_Q``.
        SLIDING_WINDOW_K: If > 0, keep only keys with
            ``pos_k - pos_q <= SLIDING_WINDOW_K``.
        Lk: Number of valid feature lanes.
    """
    # Grid layout: axis 0 = sequence, axis 1 = query head, axis 2 = query block.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Grouped-query attention: consecutive query heads map to one KV head.
    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    # First query position (within the sequence) handled by this program.
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    # K is addressed as [feature, key] so tl.dot(q, k) needs no transpose;
    # V is addressed as [key, feature]. Neither includes the sequence start
    # yet - that is added per iteration inside the loop.
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    # BLOCK_DMODEL may be wider than the real feature size; mask the padding.
    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 0 when this query block lies entirely past the end of the sequence,
    # which makes the key loop below run zero iterations.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    # Calculate the end position for attention computation
    end_n = cur_batch_seq_len

    # Apply causal attention pruning and sliding window attention pruning
    # NOTE(review): only the causal bound is applied here; the sliding
    # window does not shorten the loop range.
    end_n = tl.minimum(end_n, (start_m + 1) * BLOCK_M) if IS_CAUSAL else end_n

    # Calculate the start position for backward sliding window
    # NOTE(review): the start is always 0, so key blocks that fall outside
    # the window are still visited and removed by the element-wise mask.
    start_n_limit = 0
    end_n_limit = block_mask * end_n

    for start_n in range(start_n_limit, end_n_limit, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )

        # Apply attention mask (causal + bidirectional sliding window)
        # Position indices in the sequence
        pos_q = offs_m[:, None]  # Query positions [BLOCK_M, 1]
        pos_k = start_n + offs_n[None, :]  # Key positions [1, BLOCK_N]

        # Valid sequence mask
        mask = pos_k < cur_batch_seq_len
        # Causal mask
        if IS_CAUSAL:
            mask &= pos_q >= pos_k

        # Bidirectional sliding window masks
        # Both conditions are on constexpr values, so a disabled side
        # (value <= 0) contributes no mask at all.
        sliding_mask_q = (
            pos_q - pos_k <= SLIDING_WINDOW_Q if SLIDING_WINDOW_Q > 0 else None
        )
        sliding_mask_k = (
            pos_k - pos_q <= SLIDING_WINDOW_K if SLIDING_WINDOW_K > 0 else None
        )
        if sliding_mask_q is not None:
            mask &= sliding_mask_q
        if sliding_mask_k is not None:
            mask &= sliding_mask_k

        # Masked positions start at -inf so they vanish in the softmax.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.where(mask, 0, float("-inf"))
        qk += tl.dot(q, k)
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_ij_valid_mask = m_ij > float("-inf")
        m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0)
        # -- compute p and l_ij --
        p = tl.exp(qk - m_ij_masked[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        m_i_new_mask = m_i_new > float("-inf")
        # alpha rescales what was accumulated so far, beta the current tile.
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        # mask alpha and beta for sliding window
        # (-inf) - (-inf) is NaN, so rows with no valid key yet are forced
        # to alpha = 1 and beta = 0.
        alpha = tl.where(m_i_new_mask, alpha, 1.0)
        beta = tl.where(m_i_new_mask, beta, 0.0)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        # For sliding window there's a chance the l_i_new is 0 due to masking
        # the entire row. We need to set l_i_new 1 to avoid zero division
        # NOTE(review): m_i_new_mask is already a boolean, so comparing it
        # with -inf looks unintended (probably meant to be m_i_new_mask
        # itself). The `l_i_new != 0.0` term appears to cover the fully
        # masked case regardless - confirm before changing.
        l_i_new_mask = (l_i_new != 0.0) & (m_i_new_mask > float("-inf"))
        l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0)
        p_scale = beta / l_i_new_safe
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new_safe * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )

        # Cast probabilities to V's dtype before the second matmul.
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    # acc is already normalised, so it is written out as-is; rows past the
    # sequence end and padded feature lanes are skipped.
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )
def get_block_size(dtype: torch.dtype) -> int:
    """Choose the tile size used for both BLOCK_M and BLOCK_N.

    Returns 32 for float32 inputs. For any other dtype, returns 128 on a
    CUDA-like platform whose device capability major version is above 8,
    and 64 otherwise.
    """
    if dtype == torch.float32:
        return 32

    # The capability is only queried on CUDA-like platforms.
    if not current_platform.is_cuda_alike():
        return 64

    capability = current_platform.get_device_capability()
    return 128 if capability.major > 8 else 64
def context_attention_fwd(
    q,
    k,
    v,
    o,
    b_start_loc,
    b_seq_len,
    max_input_len,
    is_causal=True,
    sliding_window_q=None,
    sliding_window_k=None,
):
    """Launch the Triton prefill attention kernel; the result lands in ``o``.

    Args:
        q, k, v: [b * s, head, head_dim] packed token tensors.
        o: [b * s, head, head_dim] output tensor, written in place.
        b_start_loc: [b] offset of each sequence's first token.
        b_seq_len: [b] length of each sequence.
        max_input_len: Upper bound on sequence length; sets how many query
            blocks are launched per sequence.
        is_causal: Apply causal masking when True.
        sliding_window_q: Limit on ``pos_q - pos_k``; ``None`` is passed to
            the kernel as 0, which disables it.
        sliding_window_k: Limit on ``pos_k - pos_q``; ``None`` is passed to
            the kernel as 0, which disables it.
    """
    block = get_block_size(q.dtype)

    head_dim_q = q.shape[-1]
    head_dim_k = k.shape[-1]
    softmax_scale = 1.0 / (head_dim_q**0.5)

    num_seqs = b_seq_len.shape[0]
    num_q_heads = q.shape[1]
    heads_per_kv = num_q_heads // k.shape[1]

    # One program per (sequence, query head, block of queries).
    launch_grid = (num_seqs, num_q_heads, triton.cdiv(max_input_len, block))

    if head_dim_k <= 64:
        warps = 4
    else:
        warps = 8

    # The kernel takes the windows as compile-time ints; 0 means "off".
    window_q = 0 if sliding_window_q is None else sliding_window_q
    window_k = 0 if sliding_window_k is None else sliding_window_k

    _fwd_kernel[launch_grid](
        q,
        k,
        v,
        softmax_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=heads_per_kv,
        BLOCK_M=block,
        BLOCK_DMODEL=triton.next_power_of_2(head_dim_k),
        BLOCK_N=block,
        IS_CAUSAL=is_causal,
        SLIDING_WINDOW_Q=window_q,
        SLIDING_WINDOW_K=window_k,
        num_warps=warps,
        num_stages=1,
        Lk=head_dim_k,
    )
vllm/platforms/rocm.py
View file @
6aa5b18e
...
...
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
...
...
@@ -289,14 +288,6 @@ class RocmPlatform(Platform):
logger
.
info
(
"Using Aiter Flash Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
# Priority 5: If model is Encoder-only self-attention type
if
(
attn_selector_config
.
attn_type
is
not
None
and
attn_selector_config
.
attn_type
==
AttentionType
.
ENCODER_ONLY
):
logger
.
info
(
"Using FlexAttention backend."
)
return
AttentionBackendEnum
.
FLEX_ATTENTION
.
get_path
()
# Default: Triton Unified Attention
logger
.
info
(
"Using Triton Attention backend."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
6aa5b18e
...
...
@@ -13,6 +13,7 @@ from vllm.attention.backends.abstract import (
AttentionType
,
MultipleOf
,
)
from
vllm.attention.ops.triton_prefill_attention
import
context_attention_fwd
from
vllm.attention.ops.triton_reshape_and_cache_flash
import
(
triton_reshape_and_cache_flash
,
)
...
...
@@ -309,6 +310,16 @@ class TritonAttentionBackend(AttentionBackend):
def
supports_sink
(
cls
)
->
bool
:
return
True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether ``attn_type`` is handled by this backend.

    Decoder, encoder, encoder-only and encoder-decoder attention are all
    accepted.
    """
    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in accepted_types
@
classmethod
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
return
True
...
...
@@ -341,6 +352,8 @@ class TritonAttentionImpl(AttentionImpl):
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
None
:
self
.
sliding_window
=
(
-
1
,
-
1
)
elif
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
self
.
sliding_window
=
(
sliding_window
-
1
,
sliding_window
-
1
)
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
...
...
@@ -352,10 +365,6 @@ class TritonAttentionImpl(AttentionImpl):
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
attn_type
not
in
[
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
]:
raise
NotImplementedError
(
"Encoder self-attention is not implemented for TritonAttentionImpl"
)
self
.
attn_type
=
attn_type
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
...
...
@@ -417,6 +426,21 @@ class TritonAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if
self
.
attn_type
in
(
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
ENCODER
):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
self
.
_forward_encoder_attention
(
query
[:
num_actual_tokens
],
key
[:
num_actual_tokens
],
value
[:
num_actual_tokens
],
output
[:
num_actual_tokens
],
attn_metadata
,
layer
,
)
# For decoder and cross-attention, use KV cache as before
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
if
(
...
...
@@ -495,3 +519,48 @@ class TritonAttentionImpl(AttentionImpl):
)
return
output
def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """Run bidirectional encoder attention directly on Q/K/V, bypassing the KV cache.

    Args:
        query: shape = [num_encoder_tokens, num_heads, head_size]
        key: shape = [num_encoder_tokens, num_kv_heads, head_size]
        value: shape = [num_encoder_tokens, num_kv_heads, head_size]
        output: shape = [num_encoder_tokens, num_heads, head_size];
            filled in place and also returned.
        attn_metadata: Supplies query_start_loc, seq_lens and max_query_len.
        layer: The attention layer (not used by this method).

    Raises:
        NotImplementedError: If the KV cache dtype is an fp8 variant.
    """
    # FP8 KV-cache dtypes are rejected outright on this path.
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "quantization is not supported for encoder attention"
        )

    # Triton prefill kernel over the packed sequences described by the
    # metadata; the two window sides come from self.sliding_window.
    context_attention_fwd(
        q=query,
        k=key,
        v=value,
        o=output,
        b_start_loc=attn_metadata.query_start_loc,
        b_seq_len=attn_metadata.seq_lens,
        max_input_len=attn_metadata.max_query_len,
        is_causal=False,  # Encoder attention is bidirectional
        sliding_window_q=self.sliding_window[0],
        sliding_window_k=self.sliding_window[1],
    )
    return output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment