Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6bd1dd9d
Unverified
Commit
6bd1dd9d
authored
Mar 06, 2025
by
Thomas Parnell
Committed by
GitHub
Mar 06, 2025
Browse files
[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend (#14152)
parent
4f27044a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
398 additions
and
77 deletions
+398
-77
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+76
-59
vllm/attention/ops/chunked_prefill_paged_decode.py
vllm/attention/ops/chunked_prefill_paged_decode.py
+289
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+13
-1
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+20
-17
No files found.
tests/kernels/test_prefix_prefill.py
View file @
6bd1dd9d
...
...
@@ -3,6 +3,7 @@
import
math
import
random
import
time
from
collections.abc
import
Callable
import
pytest
import
torch
...
...
@@ -10,6 +11,8 @@ from xformers import ops as xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
)
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
...
...
@@ -24,6 +27,8 @@ CUDA_DEVICES = [
SLIDING_WINDOW
=
[
0
,
16
,
64
,
128
,
256
,
512
,
2048
]
KV_CACHE_DTYPES
=
[
"auto"
,
"fp8"
,
"fp8_e5m2"
]
OPS
=
[
chunked_prefill_paged_decode
,
context_attention_fwd
]
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_queries_per_kv"
,
NUM_QUERIES_PER_KV
)
...
...
@@ -32,6 +37,7 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOW
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention
(
num_heads
:
int
,
...
...
@@ -41,6 +47,7 @@ def test_contexted_kv_attention(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
)
->
None
:
if
'fp8'
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
...
...
@@ -65,6 +72,9 @@ def test_contexted_kv_attention(
block_size
=
32
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
# ensure one sequence in batch is a decode
query_lens
[
-
1
]
=
1
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
seq_lens
=
[
a
+
b
for
a
,
b
in
zip
(
query_lens
,
ctx_lens
)]
num_kv_heads
=
num_heads
//
num_queries_per_kv
...
...
@@ -144,36 +154,36 @@ def test_contexted_kv_attention(
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
op
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
context_attention_fwd
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
op
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
...
...
@@ -228,7 +238,7 @@ def test_contexted_kv_attention(
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
atol
=
1e-3
if
"fp8"
in
kv_cache_dtype
else
1e-
6
atol
=
1e-3
if
"fp8"
in
kv_cache_dtype
else
1e-
4
torch
.
testing
.
assert_close
(
output
,
output_ref
,
atol
=
atol
,
rtol
=
0
)
...
...
@@ -238,6 +248,7 @@ def test_contexted_kv_attention(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention_alibi
(
num_heads
:
int
,
...
...
@@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
)
->
None
:
if
'fp8'
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
...
...
@@ -375,36 +387,36 @@ def test_contexted_kv_attention_alibi(
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
op
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
context_attention_fwd
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
op
(
query
,
k
,
v
,
output
,
kv_cache_dtype
,
k_cache
,
v_cache
,
block_table
,
b_start_loc
,
b_seq_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
...
...
@@ -503,6 +515,7 @@ def test_contexted_kv_attention_alibi(
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOW
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention_f32
(
num_heads
:
int
,
...
...
@@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
)
->
None
:
test_contexted_kv_attention
(
num_heads
,
num_queries_per_kv
,
head_size
,
sliding_window
,
dtype
,
kv_cache_dtype
,
device
)
sliding_window
,
dtype
,
kv_cache_dtype
,
device
,
op
)
@
pytest
.
mark
.
optional
...
...
@@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_contexted_kv_attention_alibi_f32
(
num_heads
:
int
,
...
...
@@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
)
->
None
:
test_contexted_kv_attention_alibi
(
num_heads
,
num_queries_per_kv
,
head_size
,
dtype
,
kv_cache_dtype
,
device
)
dtype
,
kv_cache_dtype
,
device
,
op
)
vllm/attention/ops/chunked_prefill_paged_decode.py
0 → 100644
View file @
6bd1dd9d
# SPDX-License-Identifier: Apache-2.0
import
torch
import
triton
import
triton.language
as
tl
from
.prefix_prefill
import
context_attention_fwd
@triton.jit
def cdiv_fn(x, y):
    """Return the ceiling of ``x / y`` using integer arithmetic only."""
    # Adding (y - 1) before the floor division rounds any remainder up.
    rounded_up = x + y - 1
    return rounded_up // y
@triton.jit
def kernel_paged_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
    value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # pointer to a float32 scale; read via tl.load for fp8 caches
    v_scale,  # pointer to a float32 scale; read via tl.load for fp8 caches
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.constexpr,  # int
    query_stride_0: tl.constexpr,  # int
    query_stride_1: tl.constexpr,  # int, should be equal to head_size
    output_stride_0: tl.constexpr,  # int
    output_stride_1: tl.constexpr,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    x: tl.constexpr,  # int
    stride_k_cache_0: tl.constexpr,  # int
    stride_k_cache_1: tl.constexpr,  # int
    stride_k_cache_2: tl.constexpr,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_k_cache_4: tl.constexpr,  # int
    stride_v_cache_0: tl.constexpr,  # int
    stride_v_cache_1: tl.constexpr,  # int
    stride_v_cache_2: tl.constexpr,  # int
    stride_v_cache_3: tl.constexpr,  # int
    filter_by_query_len: tl.constexpr,  # bool
    query_start_len_ptr,  # [num_seqs+1]
):
    """Single-query paged attention for one (sequence, query head) pair.

    Each program instance is identified by ``program_id(0)`` (sequence index)
    and ``program_id(1)`` (query head index). It reads one query vector, walks
    the sequence's block table one KV-cache block at a time, and keeps a
    running maximum ``M``, a running normalizer ``L`` and an un-normalized
    accumulator ``acc`` so that the softmax-weighted sum of V can be formed
    without materializing all scores at once.

    When ``filter_by_query_len`` is true, the query token position is taken
    from ``query_start_len_ptr`` and any sequence whose query length is
    greater than 1 returns immediately without writing output. Otherwise the
    sequence index itself is used as the token index.

    fp8 K/V tiles are converted to float32, multiplied by the value loaded
    from ``k_scale`` / ``v_scale`` and cast to the query dtype.
    """
    seq_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)
    # Several query heads share one KV head.
    kv_head_idx = query_head_idx // num_queries_per_kv

    if filter_by_query_len:
        cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
        cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx +
                                              1)
        cur_batch_query_len = cur_batch_in_all_stop_index \
            - cur_batch_in_all_start_index
        # Only sequences with a query length of at most 1 are handled here.
        if cur_batch_query_len > 1:
            return
    else:
        cur_batch_in_all_start_index = seq_idx

    query_offset = (cur_batch_in_all_start_index * query_stride_0 +
                    query_head_idx * query_stride_1)

    # True for the first HEAD_SIZE lanes, False for the power-of-2 padding.
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                        0).to(tl.int1)

    # Q : (HEAD_SIZE_PADDED,), padding lanes are zero
    Q = tl.load(
        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
        mask=dim_mask,
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Running softmax state. M starts at -inf, so on the first tile
    # alpha = exp(M - m_j) is 0 whenever m_j is finite and the initial
    # L = 1.0 is discarded; if the loop runs zero times the epilogue divides
    # the zero accumulator by 1.0.
    M = tl.full([1], float("-inf"), dtype=tl.float32)
    L = tl.full([1], 1.0, dtype=tl.float32)
    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles
    for j in range(0, num_blocks):

        # Translate the logical block index j into a physical cache block.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)
        offs_d = tl.arange(0, HEAD_SIZE_PADDED)

        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_1 +
                    offs_d[:, None] * stride_v_cache_2 +
                    offs_n[None, :] * stride_v_cache_3)

        # The key cache splits the head dimension into (head_size // x, x),
        # so each head-dim index d maps to (d // x, d % x).
        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_1 +
                    (offs_d[:, None] // x) * stride_k_cache_2 +
                    offs_n[None, :] * stride_k_cache_3 +
                    (offs_d[:, None] % x) * stride_k_cache_4)

        # K : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if K_load.dtype.is_fp8():
            K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (HEAD_SIZE_PADDED, BLOCK_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        if V_load.dtype.is_fp8():
            V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Absolute token positions covered by this tile; positions at or
        # beyond seq_len get a score of -inf.
        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
        mask_new = tmp < boundary
        # S : (BLOCK_SIZE,)
        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
        S += scale * tl.sum(K * Q[:, None], axis=0)

        # Positions farther than SLIDING_WINDOW from the last token get a
        # large negative score (-10000) instead of -inf.
        if SLIDING_WINDOW > 0:
            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)

        # ALiBi bias: slope times the (non-positive) offset from the last
        # token of the sequence.
        if USE_ALIBI_SLOPES:
            S += alibi_slope * (tmp - seq_len + 1)

        # compute running maximum
        # m_j : (1,)
        m_j = tl.maximum(M, tl.max(S, axis=0))

        # P : (BLOCK_SIZE,)
        P = tl.exp(S - m_j)

        # l_j : (1,)
        l_j = tl.sum(P, axis=0)

        # alpha : (1, ) - rescales previous state to the new maximum
        alpha = tl.exp(M - m_j)

        # acc : (HEAD_SIZE_PADDED,)
        acc = acc * alpha

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (HEAD_SIZE_PADDED,)
        acc += tl.sum(V * P[None, :], axis=1)

    # epilogue: normalize by the accumulated softmax denominator
    acc = acc / L

    output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                     query_head_idx * output_stride_1)

    # Only the first HEAD_SIZE lanes are written back.
    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
             acc,
             mask=dim_mask)
def chunked_prefill_paged_decode(
    query,
    key,
    value,
    output,
    kv_cache_dtype,
    key_cache,
    value_cache,
    block_table,
    query_start_loc,
    seq_lens,
    max_query_len,
    k_scale,
    v_scale,
    alibi_slopes=None,
    sliding_window=None,
    sm_scale=None,
):
    """Run attention over a mixed batch of prefill and decode sequences.

    If ``max_query_len > 1`` the batch is first passed to
    ``context_attention_fwd`` with ``skip_decode=True``. Then
    ``kernel_paged_attention_2d`` is launched over a
    ``(num_seqs, num_query_heads)`` grid with ``filter_by_query_len=True``,
    so it only processes sequences whose query length is at most 1.
    Results are written into ``output``; nothing is returned.

    Args:
        query: Query tensor, indexed as
            ``[num_tokens, num_query_heads, head_size]``.
        key: Key tensor; ``key.shape[1]`` is used as the number of KV heads.
        value: Value tensor, forwarded to ``context_attention_fwd``.
        output: Output tensor written in place by both kernels.
        kv_cache_dtype: KV-cache dtype string. If it contains ``"fp8"`` the
            caches must be stored as ``torch.uint8`` and are reinterpreted
            as the matching float8 dtype before the decode kernel launch.
        key_cache: 5-D key cache; its last dimension is passed as ``x``.
        value_cache: 4-D value cache; ``shape[3]`` is the block size.
        block_table: Per-sequence table of physical block indices.
        query_start_loc: Start offsets of each sequence's query tokens
            (``num_seqs + 1`` entries are read by the decode kernel).
        seq_lens: Per-sequence total lengths.
        max_query_len: Longest query length in the batch.
        k_scale: Key scale, forwarded to both kernels.
        v_scale: Value scale, forwarded to both kernels.
        alibi_slopes: Optional per-head ALiBi slopes.
        sliding_window: Optional window size; ``None`` or a non-positive
            value disables the sliding window.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(head_size)``.

    Raises:
        ValueError: If ``kv_cache_dtype`` contains ``"fp8"`` but is not one
            of ``"fp8"``, ``"fp8_e4m3"`` or ``"fp8_e5m2"``.
    """
    num_query_heads = query.shape[1]
    head_size = query.shape[2]

    if sm_scale is None:
        # Scale by the head dimension (dim 2), not the number of heads
        # (dim 1), so the default matches standard scaled dot-product
        # attention.
        sm_scale = 1.0 / (head_size**0.5)

    use_alibi_slopes = alibi_slopes is not None

    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    if max_query_len > 1:
        # Prefill / chunked-prefill sequences; decode sequences are skipped
        # here and handled by the paged decode kernel below.
        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            kv_cache_dtype=kv_cache_dtype,
            k_cache=key_cache,
            v_cache=value_cache,
            b_loc=block_table,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_input_len=max_query_len,
            k_scale=k_scale,
            v_scale=v_scale,
            alibi_slopes=alibi_slopes,
            sliding_window=sliding_window,
            sm_scale=sm_scale,
            skip_decode=True,
        )

    block_size = value_cache.shape[3]
    num_seqs = len(seq_lens)
    num_queries_per_kv = num_query_heads // key.shape[1]

    # Conversion of FP8 Tensor from uint8 storage to
    # appropriate torch.dtype for interpretation by Triton
    if "fp8" in kv_cache_dtype:
        assert key_cache.dtype == torch.uint8
        assert value_cache.dtype == torch.uint8

        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            target_dtype = torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            target_dtype = torch.float8_e5m2
        else:
            # Single formatted message; passing two positional args to
            # ValueError would render the message as a tuple.
            raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")

        key_cache = key_cache.view(target_dtype)
        value_cache = value_cache.view(target_dtype)

    # One program per (sequence, query head).
    kernel_paged_attention_2d[(
        num_seqs,
        num_query_heads,
    )](
        output_ptr=output,
        query_ptr=query,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        block_tables_ptr=block_table,
        seq_lens_ptr=seq_lens,
        alibi_slopes_ptr=alibi_slopes,
        scale=sm_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=query.stride(0),
        query_stride_1=query.stride(1),
        output_stride_0=output.stride(0),
        output_stride_1=output.stride(1),
        BLOCK_SIZE=block_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        USE_ALIBI_SLOPES=use_alibi_slopes,
        SLIDING_WINDOW=sliding_window,
        x=key_cache.shape[4],
        stride_k_cache_0=key_cache.stride(0),
        stride_k_cache_1=key_cache.stride(1),
        stride_k_cache_2=key_cache.stride(2),
        stride_k_cache_3=key_cache.stride(3),
        stride_k_cache_4=key_cache.stride(4),
        stride_v_cache_0=value_cache.stride(0),
        stride_v_cache_1=value_cache.stride(1),
        stride_v_cache_2=value_cache.stride(2),
        stride_v_cache_3=value_cache.stride(3),
        filter_by_query_len=True,
        query_start_len_ptr=query_start_loc,
    )
vllm/attention/ops/prefix_prefill.py
View file @
6bd1dd9d
...
...
@@ -64,7 +64,9 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
SLIDING_WINDOW
:
tl
.
constexpr
,
SKIP_DECODE
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
start_m
=
tl
.
program_id
(
2
)
...
...
@@ -78,6 +80,9 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index
)
cur_batch_ctx_len
=
cur_batch_seq_len
-
cur_batch_query_len
if
SKIP_DECODE
and
cur_batch_query_len
==
1
:
return
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc
=
BLOCK_M
*
start_m
...
...
@@ -500,6 +505,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
SKIP_DECODE
:
tl
.
constexpr
,
):
# attn_bias[]
cur_batch
=
tl
.
program_id
(
0
)
...
...
@@ -518,6 +524,9 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index
)
cur_batch_ctx_len
=
cur_batch_seq_len
-
cur_batch_query_len
if
SKIP_DECODE
and
cur_batch_query_len
==
1
:
return
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
...
...
@@ -721,7 +730,8 @@ if triton.__version__ >= "2.1.0":
v_scale
:
torch
.
Tensor
,
alibi_slopes
=
None
,
sliding_window
=
None
,
sm_scale
=
None
):
sm_scale
=
None
,
skip_decode
=
False
):
q_dtype_is_f32
=
q
.
dtype
is
torch
.
float32
# need to reduce num. blocks when using fp32
...
...
@@ -823,6 +833,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
SKIP_DECODE
=
skip_decode
,
num_warps
=
NUM_WARPS
,
num_stages
=
1
,
)
...
...
@@ -875,6 +886,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
,
SKIP_DECODE
=
skip_decode
,
num_warps
=
NUM_WARPS
,
num_stages
=
1
,
)
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
6bd1dd9d
...
...
@@ -6,8 +6,9 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionMetadata
,
FlashAttentionMetadataBuilder
)
...
...
@@ -156,20 +157,22 @@ class ROCmAttentionImpl(AttentionImpl):
)
# Compute attention and update output up to `num_actual_tokens`.
context_attention_fwd
(
q
=
query
[:
num_actual_tokens
],
k
=
key
[:
num_actual_tokens
],
v
=
value
[:
num_actual_tokens
],
o
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
b_loc
=
attn_metadata
.
block_table
,
b_start_loc
=
attn_metadata
.
query_start_loc
,
b_seq_len
=
attn_metadata
.
seq_lens
,
max_input_len
=
attn_metadata
.
max_query_len
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
[
0
],
sm_scale
=
self
.
scale
)
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
block_table
=
attn_metadata
.
block_table
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
max_query_len
=
attn_metadata
.
max_query_len
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
[
0
],
sm_scale
=
self
.
scale
)
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment